diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index e2e6d09c29..ca54e72434 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -77,6 +77,16 @@ def setup_pytorch_extension( setup_mpi_flags(include_dirs, cxx_flags) + # Mirror the NCCL EP gate from setup.py / common CMake. When disabled, the + # ep.cpp source no-ops at the #ifdef boundary; without the define it would + # produce undefined references to nvte_ep_*. + if bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))): + cxx_flags.append("-DNVTE_WITH_NCCL_EP") + # PyTorch's symm-mem headers gate the NCCL_HAS_SYMMEM_* feature macros on + # USE_NCCL. The EP extension shares the symm-mem NCCL comm with torch, so + # it needs those macros visible. + cxx_flags.append("-DUSE_NCCL") + library_dirs = [] libraries = [] if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))): diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py new file mode 100644 index 0000000000..81f5b83883 --- /dev/null +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -0,0 +1,394 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""PyTorch EP perf bench: raw and autograd dispatch/combine on a single EP group. + +One process per GPU; launched via run_ep_bench.sh (torchrun). + +Stages (each timed in its own loop): + - dispatch_raw: _ep_dispatch_raw (no autograd, no prepare) + - ep_dispatch_fwd: ep_dispatch forward only + - ep_dispatch_fwd_bwd: ep_dispatch + backward on 0.5 * ||recv||^2 + - combine_raw: _ep_combine_raw (no autograd) + - ep_combine_fwd: ep_combine forward only + - ep_combine_fwd_bwd: ep_combine + backward + +ep_prepare runs once outside the timed loops. --kineto DIR dumps a Chrome +trace plus a per-kernel summary on rank 0. +""" + +import argparse +import gc +import os +import sys +import time +from contextlib import nullcontext + +import numpy as np +import torch +import torch.distributed as dist + +from transformer_engine.pytorch.ep import ( + EpBuffer, + ep_bootstrap, + ep_combine, + ep_dispatch, + ep_finalize, + ep_prepare, + _ep_combine_raw, + _ep_dispatch_raw, +) + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-PyTorch EP perf bench") + p.add_argument("--tokens-per-rank", type=int, default=8192) + p.add_argument("--hidden", type=int, default=7168) + p.add_argument("--top-k", type=int, default=8) + p.add_argument("--num-experts", type=int, default=256) + p.add_argument("--warmup", type=int, default=2) + p.add_argument("--iters", type=int, default=10) + p.add_argument( + "--max-num-sms", + type=int, + default=0, + help="Max SMs for dispatch/combine/preprocess kernels (0 = auto).", + ) + p.add_argument( + "--kineto", + default=None, + help="If set, dump a Kineto Chrome trace + per-kernel summary into this dir (rank 0).", + ) + p.add_argument( + "--cuda-graph", + action="store_true", + default=False, + help=( + "Capture each stage into a CUDA graph and time replay() instead of the eager call. " + "Raw + fwd-only stages use torch.cuda.graph; fwd+bwd stages use " + "torch.cuda.make_graphed_callables to capture forward and backward together." + ), + ) + p.add_argument( + "--mode-label", + default=None, + help="Optional suffix for NVTX range names (e.g. 'fused' / 'unfused').", + ) + return p.parse_args() + + +def _nvtx_funcs(): + """Return push/pop helpers using torch.cuda.nvtx if available.""" + try: + push = torch.cuda.nvtx.range_push + pop = torch.cuda.nvtx.range_pop + return push, pop + except AttributeError: + return lambda _name: None, lambda: None + + +def _device_sm() -> int: + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def _make_inputs(rank, world_size, T, H, K, E, device): + """Round-robin identity routing + uniform top-k weights.""" + topk_idx = np.empty((T, K), dtype=np.int64) + for t in range(T): + for k in range(K): + topk_idx[t, k] = ((rank * T + t) * K + k) % E + rng = np.random.default_rng(seed=42 + rank) + tokens_np = (rng.standard_normal((T, H), dtype=np.float32) * 0.5).astype(np.float32) + return ( + torch.from_numpy(topk_idx).to(device), + torch.from_numpy(tokens_np).to(device=device, dtype=torch.bfloat16), + torch.full((T, K), 1.0 / K, dtype=torch.float32, device=device), + ) + + +def _time_stage_us(name, fn, iters, nvtx_suffix, push, pop): + """Time fn for iters iterations after one untimed warmup; returns mean us.""" + # Run iters+1 times; drop the first (autotune outlier) and frame NVTX from iter 1. + total_ns = 0 + counted = 0 + for i in range(iters + 1): + if i == 1: + push(f"{name}{nvtx_suffix}") + torch.cuda.synchronize() + t0 = time.perf_counter_ns() + fn() + torch.cuda.synchronize() + dt = time.perf_counter_ns() - t0 + if i == 0: + continue + total_ns += dt + counted += 1 + pop() + return total_ns / 1e3 / counted + + +def main(): + args = _parse_args() + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank))) + device = torch.device("cuda", torch.cuda.current_device()) + + if _device_sm() < 90: + if rank == 0: + print(f"[ep_bench] SKIPPED: EP requires SM>=90 (got SM{_device_sm()})") + dist.destroy_process_group() + return + if world_size < 4: + if rank == 0: + print(f"[ep_bench] SKIPPED: EP requires >=4 ranks (got {world_size})") + dist.destroy_process_group() + return + + ep_size = world_size + E = args.num_experts + assert E % ep_size == 0, f"num_experts ({E}) must be divisible by ep_size ({ep_size})" + num_local_experts = E // ep_size + T = args.tokens_per_rank + H = args.hidden + K = args.top_k + # Conservative cap: every token could land on every local expert. + recv_pr = world_size * T * K // 2 + if rank == 0: + print( + f"[ep_bench] world={world_size} ep={ep_size} T={T} H={H} K={K} " + f"E={E} (local={num_local_experts}) recv_pr={recv_pr}" + + (f" mode={args.mode_label}" if args.mode_label else ""), + flush=True, + ) + + ep_group = dist.new_group(ranks=list(range(world_size)), backend="nccl") + ep_bootstrap( + ep_group, + num_experts=E, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=H, + max_num_sms=args.max_num_sms, + ) + + topk_idx, tokens_hbm, topk_w_hbm = _make_inputs(rank, world_size, T, H, K, E, device) + + buffer = EpBuffer( + top_k=K, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=H, + num_local_experts=num_local_experts, + ) + + tokens = tokens_hbm + topk_w = topk_w_hbm + recv_tokens = torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device) + recv_w = torch.empty(recv_pr, dtype=torch.float32, device=device) + + # -- Prepare once outside the timed loops ------------------------------ + ep_prepare(buffer, topk_idx) + torch.cuda.synchronize() + + # Pre-dispatch a steady recv_tokens / recv_w so combine stages have valid input. + _ep_dispatch_raw(buffer, topk_idx, tokens, topk_w, recv_tokens, recv_w) + torch.cuda.synchronize() + # fp-equivalent stand-in for an MLP output. + expert_out = recv_tokens.clone() + + nvtx_suffix = f"[{args.mode_label}]" if args.mode_label else "" + push, pop = _nvtx_funcs() + + # -- Stage closures ---------------------------------------------------- + # Persistent fwd+bwd inputs (make_graphed_callables needs stable storage). + tokens_p = tokens.detach().clone().requires_grad_(True) + eo_p = recv_tokens.detach().clone().requires_grad_(True) + + # Stand-in callables; the cuda-graph branch below swaps in graphed versions. + fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w)[0] # noqa: E731 + fwd_bwd_combine_fn = lambda eo: ep_combine(buffer, eo) # noqa: E731 + + def _dispatch_raw(): + _ep_dispatch_raw(buffer, topk_idx, tokens, topk_w, recv_tokens, recv_w) + + def _combine_raw(): + out_buf = torch.empty(T, H, dtype=torch.bfloat16, device=device) + _ep_combine_raw(buffer, expert_out, out_buf) + + def _ep_dispatch_fwd(): + ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) + + def _ep_dispatch_fwd_bwd(): + tokens_p.grad = None + r = fwd_bwd_dispatch_fn(tokens_p) + (0.5 * (r * r).sum(dtype=torch.float32)).backward() + + def _ep_combine_fwd(): + ep_combine(buffer, recv_tokens) + + def _ep_combine_fwd_bwd(): + eo_p.grad = None + out = fwd_bwd_combine_fn(eo_p) + (0.5 * (out * out).sum(dtype=torch.float32)).backward() + + stages = [ + ("dispatch_raw", _dispatch_raw, True), + ("ep_dispatch_fwd", _ep_dispatch_fwd, True), + ("ep_dispatch_fwd_bwd", _ep_dispatch_fwd_bwd, False), + ("combine_raw", _combine_raw, True), + ("ep_combine_fwd", _ep_combine_fwd, True), + ("ep_combine_fwd_bwd", _ep_combine_fwd_bwd, False), + ] + # Third tuple element: True = direct torch.cuda.graph capture; False = use + # make_graphed_callables (autograd-aware) instead. + + # -- Warmup ----------------------------------------------------------- + for _ in range(args.warmup): + for _name, fn, _capt in stages: + fn() + torch.cuda.synchronize() + + # -- Optional CUDA-graph capture -------------------------------------- + # Capture each capturable stage on a side stream and time .replay() + # instead of the eager call. Outputs allocated inside the + # autograd.Function's forward go through the per-capture private pool + # so addresses stay stable across replays. + captured_runners = {} + if args.cuda_graph: + # Graph fwd+bwd of the autograd-wrapped ops via make_graphed_callables. + class _DispatchMod(torch.nn.Module): + def forward(self, x): + return ep_dispatch(buffer, x, topk_idx, topk_w)[0] + + class _CombineMod(torch.nn.Module): + def forward(self, eo): + return ep_combine(buffer, eo) + + disp_mod = _DispatchMod().cuda() + comb_mod = _CombineMod().cuda() + g_disp, g_comb = torch.cuda.make_graphed_callables( + (disp_mod, comb_mod), + ((tokens_p,), (eo_p,)), + ) + fwd_bwd_dispatch_fn = g_disp + fwd_bwd_combine_fn = g_comb + + # Direct torch.cuda.graph capture for raw + fwd-only stages. + side = torch.cuda.Stream() + side.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(side): + for name, fn, direct_capturable in stages: + if not direct_capturable: + continue + fn() # prime the allocator for stable replay addresses + torch.cuda.synchronize() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + captured_runners[name] = g + torch.cuda.current_stream().wait_stream(side) + torch.cuda.synchronize() + + # -- Optional Kineto profiling ---------------------------------------- + kineto_ctx = nullcontext() + if args.kineto and rank == 0: + os.makedirs(args.kineto, exist_ok=True) + kineto_ctx = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=False, + with_stack=False, + ) + + # -- Timed loops ------------------------------------------------------ + results = {} + with kineto_ctx as prof: + for name, fn, _ in stages: + runner = fn + if name in captured_runners: + # Time replay() instead of the eager call. + graph = captured_runners[name] + runner = graph.replay + results[name] = _time_stage_us(name, runner, args.iters, nvtx_suffix, push, pop) + + if rank == 0: + label = f" [{args.mode_label}]" if args.mode_label else "" + print("", flush=True) + print(f"| stage | mean wall (us){label} |", flush=True) + print("|----------------------|---------------:|", flush=True) + for name in ( + "dispatch_raw", + "ep_dispatch_fwd", + "ep_dispatch_fwd_bwd", + "combine_raw", + "ep_combine_fwd", + "ep_combine_fwd_bwd", + ): + print(f"| {name:20s} | {results[name]:14.1f} |", flush=True) + print( + "| (dispatch fwd-raw) |" + f" {results['ep_dispatch_fwd'] - results['dispatch_raw']:14.1f} |", + flush=True, + ) + print( + "| (dispatch bwd-fwd) |" + f" {results['ep_dispatch_fwd_bwd'] - results['ep_dispatch_fwd']:14.1f} |", + flush=True, + ) + print( + "| (combine fwd-raw) |" + f" {results['ep_combine_fwd'] - results['combine_raw']:14.1f} |", + flush=True, + ) + print( + "| (combine bwd-fwd) |" + f" {results['ep_combine_fwd_bwd'] - results['ep_combine_fwd']:14.1f} |", + flush=True, + ) + print("", flush=True) + + if args.kineto and rank == 0 and prof is not None: + trace_path = os.path.join(args.kineto, "ep_bench_trace.json") + prof.export_chrome_trace(trace_path) + print(f"[ep_bench] kineto trace: {trace_path}", flush=True) + print( + prof.key_averages().table(sort_by="cuda_time_total", row_limit=30), + flush=True, + ) + kern_csv = os.path.join(args.kineto, "ep_bench_kernels.csv") + with open(kern_csv, "w") as f: + f.write("name,cuda_time_us,cpu_time_us,count\n") + for evt in prof.key_averages(): + if evt.device_time_total == 0 and evt.cpu_time_total == 0: + continue + f.write(f"{evt.key},{evt.device_time_total},{evt.cpu_time_total},{evt.count}\n") + print(f"[ep_bench] per-kernel CSV: {kern_csv}", flush=True) + + # Captured CUDA graphs (when --cuda-graph) hold references to NCCL EP + # handles and per-pool streams; drop them and sync before ep_finalize, + # otherwise the post-finalize dist.barrier can deadlock against pending + # graph state. + torch.cuda.synchronize() + if args.cuda_graph: + fwd_bwd_dispatch_fn = None + fwd_bwd_combine_fn = None + captured_runners.clear() + del g_disp, g_comb, disp_mod, comb_mod + del tokens_p, eo_p, buffer, recv_tokens, recv_w, tokens, topk_w, expert_out + gc.collect() + torch.cuda.synchronize() + # Release NCCL EP's borrowed comm before torch destroys it. + ep_finalize() + dist.barrier() + dist.destroy_process_group() + sys.stdout.flush() + sys.stderr.flush() + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/ep/bench/run_ep_bench.sh b/examples/pytorch/ep/bench/run_ep_bench.sh new file mode 100755 index 0000000000..fefecd7fa9 --- /dev/null +++ b/examples/pytorch/ep/bench/run_ep_bench.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Launcher for examples/pytorch/ep/bench/ep_bench.py. +# Examples: +# bash run_ep_bench.sh # plain run, stdout only +# bash run_ep_bench.sh --cuda-graph # capture + replay each stage as a CUDA graph +# bash run_ep_bench.sh --kineto # Chrome trace + per-kernel CSV (rank 0) +# bash run_ep_bench.sh --nsys # nsys profile on rank 0 -> results/pyt_nsys.nsys-rep + +set -uo pipefail + +NSYS=0; KINETO=0; CGRAPH=0 +for a in "$@"; do + case "$a" in + --nsys) NSYS=1 ;; + --kineto) KINETO=1 ;; + --cuda-graph) CGRAPH=1 ;; + *) echo "unknown arg: $a" >&2; exit 2 ;; + esac +done +if [ "${NSYS}" -eq 1 ] && [ "${KINETO}" -eq 1 ]; then + echo "--nsys and --kineto both attach CUPTI; pick one." >&2; exit 2 +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../../.." && pwd)" +RESULTS="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS}" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +DETECTED_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +NUM_GPUS="${NUM_GPUS:-${DETECTED_GPUS}}" +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "EP bench requires >=4 GPUs (found ${NUM_GPUS}); SKIPPING."; exit 0 +fi +if [ "${NUM_GPUS}" -gt 8 ]; then NUM_GPUS=8; fi + +: "${TIMEOUT_S:=1800}" +: "${NCCL_EP_JIT_CACHE_DIR:=${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)}" +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "${NCCL_EP_JIT_CACHE_DIR}" + +EXTRA_ARGS=() +TAG="pyt" +[ "${CGRAPH}" -eq 1 ] && EXTRA_ARGS+=(--cuda-graph) && TAG="${TAG}_cg" +if [ "${KINETO}" -eq 1 ]; then + EXTRA_ARGS+=(--kineto "${RESULTS}/kineto_${TAG}") +fi + +EP_BENCH_EXTRA_FLAGS="${EP_BENCH_EXTRA_FLAGS:-}" +LAUNCH=(torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_GPUS}" + "${SCRIPT_DIR}/ep_bench.py" "${EXTRA_ARGS[@]}" ${EP_BENCH_EXTRA_FLAGS}) + +if [ "${NSYS}" -eq 1 ]; then + NSYS_CMD=(nsys profile + --output "${RESULTS}/pyt_${TAG}_nsys" + --force-overwrite=true + --trace=cuda,nvtx + --gpu-metrics-devices=none + --cuda-um-cpu-page-faults=false + --cuda-um-gpu-page-faults=false) + echo "[run_ep_bench] launching with nsys (results/${TAG}_nsys.nsys-rep)" + timeout --foreground --signal=TERM "${TIMEOUT_S}" "${NSYS_CMD[@]}" "${LAUNCH[@]}" + RC=$? +else + timeout --foreground --signal=TERM "${TIMEOUT_S}" "${LAUNCH[@]}" + RC=$? +fi +exit $RC diff --git a/examples/pytorch/ep/bench/run_nccl_ep_bench.sh b/examples/pytorch/ep/bench/run_nccl_ep_bench.sh new file mode 100755 index 0000000000..8f6da04a00 --- /dev/null +++ b/examples/pytorch/ep/bench/run_nccl_ep_bench.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Launcher for the native NCCL EP ``ep_bench`` (baseline for PyTorch comparison). +# Usage: +# bash run_nccl_ep_bench.sh # plain run, stdout only +# bash run_nccl_ep_bench.sh --nsys # nsys → results/nccl_ep_nsys.nsys-rep + +set -uo pipefail + +NSYS=0 +for a in "$@"; do + case "$a" in + --nsys) NSYS=1 ;; + *) echo "unknown arg: $a" >&2; exit 2 ;; + esac +done + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../../.." && pwd)" +RESULTS="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS}" + +BIN="${TE_REPO_ROOT}/3rdparty/nccl/build/test/nccl_ep/ep_bench" +LIB="${TE_REPO_ROOT}/3rdparty/nccl/build/lib" +[ -x "${BIN}" ] || { echo "ep_bench not built at ${BIN}" >&2; exit 2; } + +NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "NCCL EP bench requires >=4 GPUs (found ${NUM_GPUS}); SKIPPING."; exit 0 +fi +if [ "${NUM_GPUS}" -gt 8 ]; then NUM_GPUS=8; fi + +if [ "${NSYS}" -eq 1 ]; then + ITERS=10 +else + ITERS=50 +fi +ARGS=(--algorithm ht --layout em --tokens 2048 --hidden 7168 --top-k 8 + --experts 256 --warmup 5 --iters "${ITERS}") +[ "${NSYS}" -eq 1 ] && ARGS+=(--profile) # enables NVTX ranges + cudaProfilerStart/Stop + +CMD=(/usr/local/mpi/bin/mpirun --allow-run-as-root --oversubscribe -np "${NUM_GPUS}" + -x LD_LIBRARY_PATH="${LIB}:${LD_LIBRARY_PATH:-}" + "${BIN}" "${ARGS[@]}") + +if [ "${NSYS}" -eq 1 ]; then + CMD=(nsys profile + --output "${RESULTS}/nccl_ep_nsys" + --force-overwrite=true + --capture-range=cudaProfilerApi + --capture-range-end=stop + --trace=cuda,nvtx,osrt + "${CMD[@]}") +fi + +[ "${NSYS}" -eq 1 ] && SUFFIX="_nsys" || SUFFIX="" +LOG="${RESULTS}/stdout_nccl_ep${SUFFIX}.txt" +"${CMD[@]}" 2>&1 | tee "${LOG}" +echo "Done. Log: ${LOG}" diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py new file mode 100644 index 0000000000..f72912301b --- /dev/null +++ b/examples/pytorch/ep/ep_moe.py @@ -0,0 +1,229 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""End-to-end MoE example: dispatch -> batched expert linear -> combine, fwd + bwd. + +One process per GPU; launched via run_test_ep.sh (torchrun). +""" + +import argparse +import os +import sys + +import numpy as np +import torch +import torch.distributed as dist + +from transformer_engine.pytorch.ep import ( + EpBuffer, + ep_bootstrap, + ep_combine, + ep_dispatch, + ep_finalize, +) + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-PyTorch EP MoE example (fwd + bwd)") + p.add_argument("--num-tokens", type=int, default=8, help="Per-rank token count.") + p.add_argument("--top-k", type=int, default=2) + p.add_argument("--hidden", type=int, default=32) + p.add_argument("--hidden-out", type=int, default=32) + p.add_argument("--num-experts", type=int, default=None) + p.add_argument("--check", action="store_true", default=True) + p.add_argument( + "--benchmark", + action="store_true", + help="Time fwd over HBM buffers.", + ) + p.add_argument("--benchmark-iters", type=int, default=20) + p.add_argument("--benchmark-warmup", type=int, default=5) + return p.parse_args() + + +def _make_routing(rank, T, K, E, num_local_experts): + """Deterministic routing: topk_idx[t, k] = (rank*NLE + t*K + k) % E.""" + topk_idx = np.empty((T, K), dtype=np.int64) + for t in range(T): + for k in range(K): + topk_idx[t, k] = (rank * num_local_experts + t * K + k) % E + return topk_idx + + +def _batched_expert_linear(recv_tokens, kernels, num_local_experts): + """Per-expert linear via bmm; ``recv_pr // num_local_experts`` slots per expert.""" + recv_pr, _H = recv_tokens.shape + H_out = kernels.shape[-1] + slots_per_expert = recv_pr // num_local_experts + grouped = recv_tokens.view(num_local_experts, slots_per_expert, recv_tokens.shape[-1]) + out = torch.bmm(grouped, kernels.to(grouped.dtype)) + return out.view(recv_pr, H_out) + + +def _reference_moe(tokens, topk_idx, topk_w, kernels): + T, K = topk_idx.shape + H_out = kernels.shape[-1] + out = np.zeros((T, H_out), dtype=np.float32) + for t in range(T): + tok = tokens[t].astype(np.float32) + for k in range(K): + e = int(topk_idx[t, k]) + out[t] += float(topk_w[t, k]) * (tok @ kernels[e].astype(np.float32)) + return out + + +def _reference_grad(tokens, topk_idx, topk_w, kernels): + T, K = topk_idx.shape + H = tokens.shape[-1] + ref_out = _reference_moe(tokens, topk_idx, topk_w, kernels) + grad = np.zeros((T, H), dtype=np.float32) + for t in range(T): + mixed = np.zeros_like(kernels[0]) + for k in range(K): + mixed = mixed + float(topk_w[t, k]) * kernels[int(topk_idx[t, k])] + grad[t] = ref_out[t] @ mixed.T + return ref_out, grad + + +def main(): + args = _parse_args() + + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank))) + device = torch.device("cuda", torch.cuda.current_device()) + + major, minor = torch.cuda.get_device_capability() + if major * 10 + minor < 90: + if rank == 0: + print(f"[ep_moe] SKIPPED: EP requires SM>=90 (got SM{major}{minor})") + dist.destroy_process_group() + return + + if world_size < 4: + if rank == 0: + print(f"[ep_moe] SKIPPED: EP requires >= 4 ranks (got {world_size})") + dist.destroy_process_group() + return + + ep_size = world_size + num_experts = args.num_experts if args.num_experts is not None else world_size + assert num_experts % ep_size == 0 + num_local_experts = num_experts // ep_size + T = args.num_tokens + recv_pr = ep_size * T * args.top_k + + ep_group = dist.new_group(ranks=list(range(world_size)), backend="nccl") + ep_bootstrap( + ep_group, + num_experts=num_experts, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=args.hidden, + ) + try: + _run_layer( + args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device + ) + finally: + ep_finalize() + dist.destroy_process_group() + + +def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device): + rng = np.random.default_rng(seed=42 + rank) + tokens_np = (rng.standard_normal((T, args.hidden), dtype=np.float32) * 0.5).astype(np.float32) + topk_idx_np = _make_routing(rank, T, args.top_k, num_experts, num_local_experts) + w_np = np.full((T, args.top_k), 1.0 / args.top_k, dtype=np.float32) + # Same seed across ranks -> identical kernel array everywhere. + kr = np.random.default_rng(seed=42) + kernels_np = ( + kr.standard_normal((num_experts, args.hidden, args.hidden_out), dtype=np.float32) + * (1.0 / np.sqrt(args.hidden)) + ).astype(np.float32) + + tokens = ( + torch.from_numpy(tokens_np).to(device=device, dtype=torch.bfloat16).requires_grad_(True) + ) + topk_idx = torch.from_numpy(topk_idx_np).to(device) + topk_w = torch.from_numpy(w_np).to(device) + kernels_local = torch.from_numpy( + kernels_np[rank * num_local_experts : (rank + 1) * num_local_experts] + ).to(device=device, dtype=torch.bfloat16) + + buffer = EpBuffer( + top_k=args.top_k, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=args.hidden, + num_local_experts=num_local_experts, + ) + + recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, topk_w) + expert_out = _batched_expert_linear(recv_t, kernels_local, num_local_experts) + # Apply per-slot topk weighting before combine. + expert_out = expert_out * recv_w_out.unsqueeze(-1).to(expert_out.dtype) + out = ep_combine(buffer, expert_out) + + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + + if rank == 0: + print( + f"[ep_moe] loss={loss.item():.4f} grad_tokens.shape={tuple(tokens.grad.shape)} " + f"ep={ep_size} num_experts={num_experts} recv_pr={recv_pr}" + ) + + if args.benchmark: + # Time dispatch + expert + combine over HBM buffers. + import time + + torch.cuda.synchronize() + dist.barrier() + for _ in range(args.benchmark_warmup): + rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) + eo = _batched_expert_linear(rt, kernels_local, num_local_experts) + eo = eo * rw.unsqueeze(-1).to(eo.dtype) + ep_combine(buffer, eo) + torch.cuda.synchronize() + dist.barrier() + t0 = time.perf_counter() + for _ in range(args.benchmark_iters): + rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) + eo = _batched_expert_linear(rt, kernels_local, num_local_experts) + eo = eo * rw.unsqueeze(-1).to(eo.dtype) + ep_combine(buffer, eo) + torch.cuda.synchronize() + dt_ms = (time.perf_counter() - t0) * 1000.0 / args.benchmark_iters + if rank == 0: + print(f"[ep_moe --benchmark] HBM: {dt_ms:.3f} ms/iter (iters={args.benchmark_iters})") + + if args.check: + # All-gather inputs/outputs/grads for a global reference comparison. + global_tokens = [torch.empty_like(tokens) for _ in range(world_size)] + global_topk_idx = [torch.empty_like(topk_idx) for _ in range(world_size)] + global_topk_w = [torch.empty_like(topk_w) for _ in range(world_size)] + global_out = [torch.empty_like(out) for _ in range(world_size)] + global_grad = [torch.empty_like(tokens.grad) for _ in range(world_size)] + dist.all_gather(global_tokens, tokens.detach()) + dist.all_gather(global_topk_idx, topk_idx) + dist.all_gather(global_topk_w, topk_w) + dist.all_gather(global_out, out.detach()) + dist.all_gather(global_grad, tokens.grad) + if rank == 0: + all_tokens = torch.cat(global_tokens).float().cpu().numpy() + all_idx = torch.cat(global_topk_idx).cpu().numpy() + all_w = torch.cat(global_topk_w).cpu().numpy() + all_out = torch.cat(global_out).float().cpu().numpy() + all_grad = torch.cat(global_grad).float().cpu().numpy() + ref_out, ref_grad = _reference_grad(all_tokens, all_idx, all_w, kernels_np) + np.testing.assert_allclose(all_out, ref_out, rtol=5e-2, atol=5e-2) + np.testing.assert_allclose(all_grad, ref_grad, rtol=5e-2, atol=5e-2) + print(f"[ep_moe] --check PASSED (ref_out.sum()={float(ref_out.sum()):.4f})") + + +if __name__ == "__main__": + main() + sys.exit(0) diff --git a/examples/pytorch/ep/run_test_ep.sh b/examples/pytorch/ep/run_test_ep.sh new file mode 100755 index 0000000000..13b41f4cb2 --- /dev/null +++ b/examples/pytorch/ep/run_test_ep.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -uo pipefail + +DETECTED_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +NUM_GPUS="${NUM_GPUS:-${DETECTED_GPUS}}" +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "EP requires >= 4 GPUs (found ${NUM_GPUS}); SKIPPING." + exit 0 +fi +if [ "${NUM_GPUS}" -gt 8 ]; then NUM_GPUS=8; fi + +: ${TE_PATH:=/opt/transformerengine} +: ${TEST_TIMEOUT_S:=120} + +SCRIPT="${TE_PATH}/examples/pytorch/ep/ep_moe.py" +export PYTHONPATH="${TE_PATH}${PYTHONPATH:+:${PYTHONPATH}}" + +# Stage JIT cubins on tmpfs for fast iteration. +: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"} +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "$NCCL_EP_JIT_CACHE_DIR" + +echo "*** Executing ep_moe.py across ${NUM_GPUS} GPUs (timeout=${TEST_TIMEOUT_S}s) ***" +timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_GPUS}" \ + "${SCRIPT}" --check 2>&1 | tee stdout_ep_moe.txt +RC=${PIPESTATUS[0]} + +RET=0 +if [ "${RC}" -ne 0 ]; then RET=1; fi +if grep -qE "(^|]:)FAILED|(^|]:)Traceback" stdout_ep_moe.txt; then RET=1; fi +rm -f stdout_ep_moe.txt +exit $RET diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 7eb34a62e4..50a51353d1 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -50,6 +50,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_ep.xml $TE_PATH/tests/pytorch/distributed/test_ep.py || test_fail "test_ep.py" # debug tests diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py new file mode 100644 index 0000000000..b09071a57c --- /dev/null +++ b/tests/pytorch/distributed/run_ep.py @@ -0,0 +1,371 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Multi-process PyTorch EP tests, launched via torchrun (one process per GPU).""" + +import os +import sys +import unittest + +import numpy as np +import torch +import torch.distributed as dist + +from transformer_engine.pytorch.ep import ( + EpBuffer, + ep_bootstrap, + ep_finalize, + ep_prepare, + ep_dispatch, + ep_combine, + symm_mem_alloc, + _ep_combine_raw, + _ep_dispatch_raw, +) + + +ZERO_COPY = False + +# Must come after the transformer_engine import so libtransformer_engine.so is loaded. +import transformer_engine_torch as tex # noqa: F401 + + +NUM_LOCAL_EXPERTS = 2 +HIDDEN_DIM = 32 +TOP_K = 2 +TOKENS_PER_RANK = 4 + + +def _device_sm() -> int: + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def _build_ep_group(): + """EP group spanning all ranks of the default PG.""" + world_pg = dist.distributed_c10d._get_default_group() + ranks = list(range(world_pg.size())) + return dist.new_group(ranks=ranks, backend="nccl") + + +def _make_identity_inputs(rank, ep_size, device="cuda"): + """Per-rank identity routing + uniform weights so combine matches tokens.""" + T = TOKENS_PER_RANK + E = ep_size * NUM_LOCAL_EXPERTS + topk_idx = np.empty((T, TOP_K), dtype=np.int64) + base = rank * T + for t in range(T): + for k in range(TOP_K): + topk_idx[t, k] = ((base + t) * TOP_K + k) % E + tokens_np = np.linspace( + 0.1 + rank * 0.01, 0.9 + rank * 0.01, T * HIDDEN_DIM, dtype=np.float32 + ).reshape(T, HIDDEN_DIM) + topk_weights = np.full((T, TOP_K), 1.0 / TOP_K, dtype=np.float32) + return ( + torch.from_numpy(topk_idx).to(device), + torch.from_numpy(tokens_np).to(device=device, dtype=torch.bfloat16), + torch.from_numpy(topk_weights).to(device), + ) + + +class _Cfg: + rank: int + world_size: int + ep_size: int + num_experts: int + recv_capacity_per_rank: int + device: torch.device + + +def _make_cfg() -> _Cfg: + cfg = _Cfg() + cfg.rank = dist.get_rank() + cfg.world_size = dist.get_world_size() + cfg.ep_size = cfg.world_size + cfg.num_experts = NUM_LOCAL_EXPERTS * cfg.ep_size + T = TOKENS_PER_RANK + active = min(cfg.num_experts, T * cfg.ep_size * TOP_K) + overconc = cfg.num_experts // active + cfg.recv_capacity_per_rank = NUM_LOCAL_EXPERTS * max(T * cfg.ep_size * TOP_K, 16) * overconc * 2 + cfg.device = torch.device("cuda", torch.cuda.current_device()) + return cfg + + +class TestEP(unittest.TestCase): + cfg: _Cfg + ep_group: dist.ProcessGroup + + @classmethod + def setUpClass(cls): + if _device_sm() < 90: + raise unittest.SkipTest(f"NCCL EP requires SM>=90 (got SM{_device_sm()})") + cls.cfg = _make_cfg() + cls.ep_group = _build_ep_group() + ep_bootstrap( + cls.ep_group, + num_experts=cls.cfg.num_experts, + max_tokens_per_rank=TOKENS_PER_RANK, + recv_capacity_per_rank=cls.cfg.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + zero_copy=ZERO_COPY, + ) + + def _make_buffer(self, alignment=0, top_k=TOP_K): + return EpBuffer( + top_k=top_k, + max_tokens_per_rank=TOKENS_PER_RANK, + recv_capacity_per_rank=self.cfg.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + num_local_experts=NUM_LOCAL_EXPERTS, + alignment=alignment, + ) + + def _make_raw_recv(self, dtype=torch.bfloat16): + """Raw recv tensors + token_counts for the primitive tests.""" + rc = self.cfg.recv_capacity_per_rank + return ( + torch.empty(rc, HIDDEN_DIM, dtype=dtype, device=self.cfg.device), + torch.empty(rc, dtype=torch.float32, device=self.cfg.device), + torch.empty(NUM_LOCAL_EXPERTS, dtype=torch.int32, device=self.cfg.device), + ) + + @staticmethod + def _weighted(recv_tokens, recv_w): + """fp32 per-slot weighting + cast back; matches the upstream combine input.""" + mask = (recv_w != 0).to(torch.float32).unsqueeze(-1) + return (recv_tokens.float() * recv_w.unsqueeze(-1).float() * mask).to(recv_tokens.dtype) + + def _moe_step(self, buffer, topk_idx, tokens, w): + recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, w) + eo = self._weighted(recv_t, recv_w_out) + return ep_combine(buffer, eo) + + # Prepare + + def test_primitive_prepare(self): + buf = self._make_buffer() + topk_idx, _toks, _w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + token_counts = ep_prepare(buf, topk_idx) + torch.cuda.synchronize() + self.assertEqual(token_counts.shape, (NUM_LOCAL_EXPERTS,)) + local = int(token_counts.sum().item()) + total = torch.tensor([local], dtype=torch.int64, device=self.cfg.device) + dist.all_reduce(total, op=dist.ReduceOp.SUM, group=self.ep_group) + self.assertEqual(int(total.item()), self.cfg.world_size * TOKENS_PER_RANK * TOP_K) + + # Identity round-trip via raw primitives + + def test_primitive_dispatch_combine_identity(self): + buf = self._make_buffer() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + recv_tokens, recv_w, _ = self._make_raw_recv() + ep_prepare(buf, topk_idx) + _ep_dispatch_raw(buf, topk_idx, tokens, w, recv_tokens, recv_w) + result = torch.empty_like(tokens) + _ep_combine_raw(buf, self._weighted(recv_tokens, recv_w), result) + torch.cuda.synchronize() + torch.testing.assert_close(result.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + # Autograd + + def test_dispatch_fwd_bwd(self): + """0.5*||recv_tokens||^2 ; grad_tokens equals TOP_K * tokens.""" + buf = self._make_buffer() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + recv_t, recv_w, _tc = ep_dispatch(buf, tokens_p, topk_idx, w) + # Pull recv_w into the loss with a zero scale so both dispatch outputs + # contribute a (possibly-zero) grad — backward respects user-supplied + # grad inputs and won't fabricate Nones into zeros. + loss = 0.5 * (recv_t.float() ** 2).sum() + 0.0 * recv_w.float().sum() + loss.backward() + torch.cuda.synchronize() + torch.testing.assert_close( + tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 + ) + + def test_combine_fwd_bwd(self): + """Full dispatch + combine fwd+bwd; identity inputs round-trip.""" + buf = self._make_buffer() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + out = self._moe_step(buf, topk_idx, tokens_p, w) + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) + torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + # Multi-iter stability + + def test_dispatch_fwd_bwd_multiple_iterations(self): + """5 fwd+bwd iters on the same EpBuffer must be bit-stable.""" + buf = self._make_buffer() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + + def one_step(): + tokens_p = tokens.detach().clone().requires_grad_(True) + out = self._moe_step(buf, topk_idx, tokens_p, w) + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + return out.detach().clone(), tokens_p.grad.detach().clone() + + out_ref, grad_ref = one_step() + torch.cuda.synchronize() + for _ in range(4): + out_i, grad_i = one_step() + torch.cuda.synchronize() + torch.testing.assert_close(out_i, out_ref, atol=0, rtol=0) + torch.testing.assert_close(grad_i, grad_ref, atol=0, rtol=0) + + # CUDA graph + + def test_cuda_graph_capture(self): + """Capture raw dispatch+combine into a CUDA graph; replay must be bit-stable.""" + buf = self._make_buffer() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + recv_tokens, recv_w, _ = self._make_raw_recv() + result = torch.empty_like(tokens) + + def step(): + ep_prepare(buf, topk_idx) + _ep_dispatch_raw(buf, topk_idx, tokens, w, recv_tokens, recv_w) + _ep_combine_raw(buf, self._weighted(recv_tokens, recv_w), result) + + for _ in range(3): + step() + torch.cuda.synchronize() + + # Routing is fixed per layer; prepare runs once before capture. + ep_prepare(buf, topk_idx) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + with torch.cuda.graph(graph): + _ep_dispatch_raw(buf, topk_idx, tokens, w, recv_tokens, recv_w) + _ep_combine_raw(buf, self._weighted(recv_tokens, recv_w), result) + torch.cuda.current_stream().wait_stream(s) + torch.cuda.synchronize() + + ref = result.clone() + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + torch.testing.assert_close(result.float(), ref.float(), atol=0, rtol=0) + + # PP-1F1B handle isolation + + def test_pp_1f1b_two_handles(self): + """PP-1F1B interleave (F0 F1 B0 F2 B1 B2) over 3 per-microbatch buffers.""" + T, H = TOKENS_PER_RANK, HIDDEN_DIM + idx, _toks, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + scales = (0.13, 0.41, 0.77) + buffers, tokens, tokens_p = [], [], [] + for s in scales: + buffers.append(self._make_buffer()) + t = torch.full( + (T, H), s + self.cfg.rank * 0.01, dtype=torch.bfloat16, device=self.cfg.device + ) + tokens.append(t) + tokens_p.append(t.detach().clone().requires_grad_(True)) + + recv = [None, None, None] + + def fwd(k): + recv[k], _, _ = ep_dispatch(buffers[k], tokens_p[k], idx, w) + + def bwd(k): + (0.5 * (recv[k].float() ** 2).sum()).backward() + recv[k] = None + + fwd(0) + fwd(1) + bwd(0) + fwd(2) + bwd(1) + bwd(2) + torch.cuda.synchronize() + for k in range(3): + torch.testing.assert_close( + tokens_p[k].grad.float(), + tokens[k].float() * float(TOP_K), + atol=5e-2, + rtol=5e-2, + ) + + # Caller-supplied output buffers (autograd) + + def test_dispatch_caller_recv_buffers_autograd(self): + """ep_dispatch with caller-supplied recv buffers; fwd+bwd matches default-alloc grads.""" + buf = self._make_buffer() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + recv_tokens, recv_w, _ = self._make_raw_recv() + tokens_p = tokens.detach().clone().requires_grad_(True) + rt, rw, _tc = ep_dispatch( + buf, tokens_p, topk_idx, w, recv_tokens=recv_tokens, recv_topk_weights=recv_w + ) + self.assertEqual(rt.data_ptr(), recv_tokens.data_ptr()) + self.assertEqual(rw.data_ptr(), recv_w.data_ptr()) + (0.5 * (rt.float() ** 2).sum() + 0.0 * rw.float().sum()).backward() + torch.cuda.synchronize() + torch.testing.assert_close( + tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 + ) + + def test_combine_grad_expert_out_autograd(self): + """ep_combine with caller-supplied grad_expert_out; bwd writes into that slot.""" + buf = self._make_buffer() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + recv_t, recv_w, _ = ep_dispatch(buf, tokens_p, topk_idx, w) + eo = self._weighted(recv_t, recv_w) + grad_eo = torch.empty_like(eo) + gp = grad_eo.data_ptr() + out = ep_combine(buf, eo, grad_expert_out=grad_eo) + (0.5 * (out.float() ** 2).sum()).backward() + torch.cuda.synchronize() + self.assertEqual(grad_eo.data_ptr(), gp) + torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) + torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2) + + # Input validation + + def test_topk_int32_raises_clear_error(self): + buf = self._make_buffer() + topk_idx_int32 = torch.zeros( + TOKENS_PER_RANK, TOP_K, dtype=torch.int32, device=self.cfg.device + ) + with self.assertRaises(RuntimeError) as cm: + ep_prepare(buf, topk_idx_int32) + msg = str(cm.exception) + self.assertIn("topk_idx", msg) + self.assertIn(".long()", msg) + + +def _init_distributed(): + dist.init_process_group(backend="nccl") + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + try: + from torch.distributed import _symmetric_memory as _symm_mem + + _symm_mem.set_backend("NCCL") + except (ImportError, RuntimeError): + pass + + +if __name__ == "__main__": + _init_distributed() + loader = unittest.TestLoader() + name_filter = os.environ.get("NVTE_EP_TEST_FILTER") + if name_filter: + loader.testMethodPrefix = name_filter + suite = loader.loadTestsFromTestCase(TestEP) + runner = unittest.TextTestRunner(stream=sys.stdout, verbosity=2) + result = runner.run(suite) + dist.barrier() + ep_finalize() + dist.destroy_process_group() + sys.exit(0 if result.wasSuccessful() else 1) diff --git a/tests/pytorch/distributed/run_test_ep.sh b/tests/pytorch/distributed/run_test_ep.sh new file mode 100755 index 0000000000..92d63cff7e --- /dev/null +++ b/tests/pytorch/distributed/run_test_ep.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Launcher for tests/pytorch/distributed/run_ep.py. Auto-detects GPU count. +# Short timeout by default to surface hangs early. + +set -uo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +DETECTED_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +if [ "${DETECTED_GPUS}" -lt 4 ]; then + echo "EP requires >= 4 GPUs (found ${DETECTED_GPUS}); SKIPPING." + exit 0 +fi +NUM_RANKS="${NVTE_TEST_EP_NUM_RANKS:-${DETECTED_GPUS}}" +if [ "${NUM_RANKS}" -gt 8 ]; then NUM_RANKS=8; fi + +# Short timeout to detect hangs early. +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-120}" + +# Stage NCCL EP JIT cubins on tmpfs to keep iteration fast. +: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"} +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "$NCCL_EP_JIT_CACHE_DIR" + +SCRIPT="${SCRIPT_DIR}/run_ep.py" +echo "=== Running ${SCRIPT} on ${NUM_RANKS} GPUs (timeout=${TEST_TIMEOUT_S}s) ===" + +# setsid + kill-after so SIGKILL takes down the whole process group, not just torchrun. +setsid timeout --foreground --kill-after=10 --signal=TERM "${TEST_TIMEOUT_S}" \ + torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_RANKS}" \ + "${SCRIPT}" 2>&1 | tee stdout_ep.txt +RC=${PIPESTATUS[0]} +pkill -9 -f "tests/pytorch/distributed/run_ep.py" 2>/dev/null || true + +RET=0 +if [ "${RC}" -ne 0 ]; then + echo "torchrun exited with ${RC}" + RET=1 +fi +# Match unittest failure markers and unhandled Python tracebacks; torchrun +# prefixes per-rank stderr with "[rankN]:" so don't anchor at column 0. +if grep -qE "(^|]:)FAILED|(^|]:)Traceback" stdout_ep.txt; then RET=1; fi +if ! grep -qE "Ran [0-9]+ test|^OK$" stdout_ep.txt; then + echo "ERROR: no test summary — likely hang or early crash" + RET=1 +fi + +if [ -z "${KEEP_EP_LOGS:-}" ]; then rm -f stdout_ep.txt; fi +exit $RET diff --git a/tests/pytorch/distributed/test_ep.py b/tests/pytorch/distributed/test_ep.py new file mode 100644 index 0000000000..81eef9a3c1 --- /dev/null +++ b/tests/pytorch/distributed/test_ep.py @@ -0,0 +1,31 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Pytest driver — spawns run_ep.py under torchrun and asserts the suite passed.""" + +import os +import subprocess +from pathlib import Path + +import pytest +import torch + +TEST_ROOT = Path(__file__).parent.resolve() +WORKER = TEST_ROOT / "run_ep.py" +LAUNCHER = TEST_ROOT / "run_test_ep.sh" + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="EP requires >= 4 GPUs") +def test_multi_process_ep(): + """Launch the EP unit-test suite across all visible GPUs. + + Short timeout so a hang on any rank surfaces fast rather than burning CI time. + """ + timeout_s = int(os.environ.get("NVTE_TEST_EP_TIMEOUT_S", "180")) + proc = subprocess.run( + ["bash", str(LAUNCHER)], + env={**os.environ, "KEEP_EP_LOGS": "1", "TEST_TIMEOUT_S": str(timeout_s)}, + timeout=timeout_s + 30, + check=False, + ) + assert proc.returncode == 0, f"EP test suite failed (rc={proc.returncode})" diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2b4f899e1d..cf1481a1b6 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -9,6 +9,7 @@ #include +#include #include #include #include @@ -647,6 +648,42 @@ void inplace_multi_tensor_swizzle_scales_for_gemm_unchecked(std::vector +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "transformer_engine/comm_window.h" + +#ifdef NCCL_HAS_SYMMEM_SUPPORT +#include +#endif + +#include "../common.h" +#include "../extensions.h" +#include "transformer_engine/gemm.h" + +namespace transformer_engine::pytorch { + +namespace { + +// EP process group name, captured at ep_initialize. Used by the symm-mem +// window resolver below to look up SymmetricMemory for payload tensors. +// Empty until ep_initialize. +std::string g_ep_group_name; // NOLINT(runtime/string) + +// True while the EP backend holds a borrowed reference to torch's NCCL comm. +bool g_ep_initialized = false; + +// Zero-copy IO toggle captured at ep_initialize. Atomic so the Python-side +// toggle is safe against concurrent ep_dispatch/combine (which release the GIL). +std::atomic g_zero_copy_enabled{false}; + +// Sentinel returned by maybe_make_window when zero-copy is off or the tensor +// is not symm-mem-backed; the backend treats it as "no window, use staged copy". +constexpr NVTECommWindow kNoWindow = {nullptr, 0}; + +// Resolve ``t`` to an NCCL symm-mem window for the zero-copy one-sided path. +// Returns ``kNoWindow`` when symm-mem support isn't compiled in, zero-copy is +// disabled, no group is set, or ``t`` isn't symm-mem-backed; callers pass the +// resulting window unconditionally to the backend. +NVTECommWindow maybe_make_window(const at::Tensor& t) { +#ifdef NCCL_HAS_SYMMEM_SUPPORT + if (!g_zero_copy_enabled.load(std::memory_order_relaxed)) return kNoWindow; + if (g_ep_group_name.empty()) return kNoWindow; + c10::intrusive_ptr sm; + try { + sm = c10d::symmetric_memory::rendezvous(t, g_ep_group_name); + } catch (const std::exception&) { + return kNoWindow; // Tensor not symm-mem-backed; fall back to staged copy. + } + if (sm == nullptr) return kNoWindow; + auto* nccl_sm = dynamic_cast(sm.get()); + NVTE_CHECK(nccl_sm != nullptr, + "Symm-mem backend mismatch: expected NCCLSymmetricMemory. Set the backend to " + "\"NCCL\" before allocating EP payload buffers."); + return NVTECommWindow{static_cast(nccl_sm->get_window()), + static_cast(nccl_sm->get_offset())}; +#else + (void)t; + return kNoWindow; +#endif +} + +// When zero-copy is enabled, the named tensor must be symm-mem-backed on the +// EP group. Throws a clear error otherwise. No-op when zero-copy is off or +// symm-mem support isn't compiled in. Mirrors maybe_make_window's resolution +// path but turns the "not symm-mem" outcome into a hard error. +void check_symm_mem_required(const at::Tensor& t, const char* name) { +#ifdef NCCL_HAS_SYMMEM_SUPPORT + if (!g_zero_copy_enabled.load(std::memory_order_relaxed)) return; + NVTE_CHECK(!g_ep_group_name.empty(), + "Zero-copy is enabled but EP group name is unset; call ep_initialize first."); + c10::intrusive_ptr sm; + try { + sm = c10d::symmetric_memory::rendezvous(t, g_ep_group_name); + } catch (const std::exception&) { + sm = nullptr; + } + NVTE_CHECK(sm != nullptr, "ep zero-copy: ", name, + " must be symm-mem-backed on the EP group (allocate via symm_mem_alloc)."); +#else + (void)t; + (void)name; +#endif +} + +// The backend only accepts int64 topk_idx. The PyTorch wrapper enforces this +// at the boundary so the per-step ops don't need an upcast workspace. +void check_topk_idx_int64(at::Tensor topk_idx) { + NVTE_CHECK(topk_idx.is_contiguous(), "topk_idx must be contiguous"); + NVTE_CHECK(topk_idx.scalar_type() == at::kLong, + "topk_idx must be int64; got dtype=", c10::toString(topk_idx.scalar_type()), + ". Cast with topk_idx.long() before calling."); +} + +using Shape = std::vector; + +} // namespace + +bool ep_get_zero_copy() { return g_zero_copy_enabled.load(std::memory_order_relaxed); } + +// ── Bootstrap ──────────────────────────────────────────────────────────────── +// Borrows torch's NCCL host comm (from ``ProcessGroupNCCL._comm_ptr()``). +// ``group_name`` is captured for the symm-mem window resolver. + +void ep_initialize(uintptr_t comm_ptr, const std::string& group_name, int64_t num_experts, + int64_t max_tokens_per_rank, int64_t max_recv_tokens_per_rank, + int64_t hidden_dim, int64_t max_num_sms, pybind11::object max_token_dtype, + bool zero_copy) { + NVTE_CHECK(!group_name.empty(), "group_name must be non-empty (used for symm-mem lookup)"); + NVTE_CHECK(comm_ptr != 0, "comm_ptr must be non-null (torch NCCL host comm pointer)"); + NVTE_CHECK(!g_ep_initialized, "ep_initialize called twice without ep_finalize"); + + auto ep_comm = reinterpret_cast(comm_ptr); + int ep_size = 0; + NVTE_CHECK(ncclCommCount(ep_comm, &ep_size) == ncclSuccess, "ncclCommCount failed"); + auto torch_dtype = max_token_dtype.cast(); + NVTEEpGroupConfig cfg{ + /*ep_size=*/ep_size, + /*num_experts=*/static_cast(num_experts), + /*max_tokens_per_rank=*/static_cast(max_tokens_per_rank), + /*max_recv_tokens_per_rank=*/static_cast(max_recv_tokens_per_rank), + /*hidden_dim=*/static_cast(hidden_dim), + /*max_num_sms=*/static_cast(max_num_sms), + /*max_token_dtype=*/static_cast(GetTransformerEngineDType(torch_dtype)), + /*zero_copy=*/zero_copy ? 1 : 0, + }; + nvte_ep_initialize(static_cast(ep_comm), cfg); + g_zero_copy_enabled.store(zero_copy, std::memory_order_relaxed); + g_ep_initialized = true; + g_ep_group_name = group_name; +} + +void ep_finalize() { + if (!g_ep_initialized) return; + // The borrowed comm is owned by torch's symm-mem layer; don't destroy it. + nvte_ep_shutdown(); + g_ep_initialized = false; + g_ep_group_name.clear(); + g_zero_copy_enabled.store(false, std::memory_order_relaxed); +} + +namespace { + +NVTEEpLayerConfig make_layer_cfg(int64_t top_k, int64_t dispatch_output_per_expert_alignment) { + return NVTEEpLayerConfig{ + /*top_k=*/static_cast(top_k), + /*dispatch_output_per_expert_alignment=*/ + static_cast(dispatch_output_per_expert_alignment), + }; +} + +} // namespace + +int64_t ep_handle_mem_size(int64_t top_k, int64_t dispatch_output_per_expert_alignment) { + return static_cast( + nvte_ep_handle_mem_size(make_layer_cfg(top_k, dispatch_output_per_expert_alignment))); +} + +// ── Per-step ops ───────────────────────────────────────────────────────────── + +void ep_prepare(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor token_counts, int64_t top_k, + int64_t dispatch_output_per_expert_alignment) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(topk_idx.dim() >= 2, "topk_idx must be at least 2D [..., top_k]"); + check_topk_idx_int64(topk_idx); + const size_t T_flat = topk_idx.numel() / topk_idx.size(-1); + const size_t topk_n = static_cast(topk_idx.size(-1)); + + auto topk_idx_te = + makeTransformerEngineTensor(topk_idx.data_ptr(), Shape{T_flat, topk_n}, DType::kInt64); + auto token_counts_te = makeTransformerEngineTensor( + token_counts.data_ptr(), Shape{static_cast(token_counts.numel())}, DType::kInt32); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + + nvte_ep_prepare(handle_mem_te.data(), topk_idx_te.data(), token_counts_te.data(), + make_layer_cfg(top_k, dispatch_output_per_expert_alignment), stream); +} + +void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, + at::Tensor topk_weights, at::Tensor recv_tokens, at::Tensor recv_topk_weights) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(tokens.dim() >= 2, "tokens must be at least 2D [..., H]"); + NVTE_CHECK(topk_idx.dim() >= 2, "topk_idx must be at least 2D [..., top_k]"); + NVTE_CHECK(topk_weights.dim() >= 2, "topk_weights must be at least 2D [..., top_k]"); + NVTE_CHECK(recv_tokens.dim() >= 2, "recv_tokens must be at least 2D [..., recv_pr, H]"); + check_topk_idx_int64(topk_idx); + NVTE_CHECK(tokens.is_contiguous(), "tokens must be contiguous"); + NVTE_CHECK(topk_weights.is_contiguous(), "topk_weights must be contiguous"); + NVTE_CHECK(recv_tokens.is_contiguous(), "recv_tokens must be contiguous"); + NVTE_CHECK(recv_topk_weights.is_contiguous(), "recv_topk_weights must be contiguous"); + + const size_t H = static_cast(tokens.size(-1)); + const size_t T_flat = tokens.numel() / H; + const size_t topk_n = static_cast(topk_idx.size(-1)); + const size_t recv_pr = recv_tokens.numel() / H; + + NVTE_CHECK(static_cast(topk_weights.size(-1)) == topk_n, + "topk_weights last dim must equal topk_idx last dim"); + NVTE_CHECK(static_cast(topk_idx.numel()) == T_flat * topk_n, + "topk_idx token count must equal tokens token count"); + NVTE_CHECK(static_cast(topk_weights.numel()) == T_flat * topk_n, + "topk_weights token count must equal tokens token count"); + NVTE_CHECK(static_cast(recv_topk_weights.numel()) == recv_pr, + "recv_topk_weights total size must equal recv_tokens recv_pr"); + NVTE_CHECK(recv_tokens.scalar_type() == tokens.scalar_type(), "recv_tokens dtype (", + c10::toString(recv_tokens.scalar_type()), ") must match tokens dtype (", + c10::toString(tokens.scalar_type()), ")"); + check_symm_mem_required(recv_tokens, "recv_tokens"); + check_symm_mem_required(recv_topk_weights, "recv_topk_weights"); + + auto tok_dtype = GetTransformerEngineDType(tokens.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto topk_idx_te = + makeTransformerEngineTensor(topk_idx.data_ptr(), Shape{T_flat, topk_n}, DType::kInt64); + auto tokens_te = makeTransformerEngineTensor(tokens.data_ptr(), Shape{T_flat, H}, tok_dtype); + auto topk_w_te = + makeTransformerEngineTensor(topk_weights.data_ptr(), Shape{T_flat, topk_n}, DType::kFloat32); + auto recv_tokens_te = + makeTransformerEngineTensor(recv_tokens.data_ptr(), Shape{recv_pr, H}, tok_dtype); + auto recv_topk_w_te = + makeTransformerEngineTensor(recv_topk_weights.data_ptr(), Shape{recv_pr}, DType::kFloat32); + + // top_k / alignment are carried by the cached layer_cfg seeded at ep_prepare; + // per-step ops look them up by handle_mem pointer in the backend. + NVTECommWindow tokens_win = maybe_make_window(tokens); + NVTECommWindow topk_w_win = maybe_make_window(topk_weights); + NVTECommWindow recv_tokens_win = maybe_make_window(recv_tokens); + NVTECommWindow recv_topk_w_win = maybe_make_window(recv_topk_weights); + nvte_ep_dispatch(handle_mem_te.data(), topk_idx_te.data(), tokens_te.data(), tokens_win, + topk_w_te.data(), topk_w_win, recv_tokens_te.data(), recv_tokens_win, + recv_topk_w_te.data(), recv_topk_w_win, stream); +} + +void ep_combine(at::Tensor handle_mem, at::Tensor expert_out, at::Tensor result) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(expert_out.dim() >= 2, "expert_out must be at least 2D [..., recv_pr, H]"); + NVTE_CHECK(result.dim() >= 2, "result must be at least 2D [..., H]"); + NVTE_CHECK(expert_out.is_contiguous(), "expert_out must be contiguous"); + + const size_t H = static_cast(expert_out.size(-1)); + const size_t recv_pr = expert_out.numel() / H; + const size_t T_flat = result.numel() / H; + NVTE_CHECK(static_cast(result.size(-1)) == H, + "result hidden dim must equal expert_out hidden dim"); + NVTE_CHECK(result.scalar_type() == expert_out.scalar_type(), "result dtype (", + c10::toString(result.scalar_type()), ") must match expert_out dtype (", + c10::toString(expert_out.scalar_type()), ")"); + check_symm_mem_required(expert_out, "expert_out"); + + auto eo_dtype = GetTransformerEngineDType(expert_out.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto expert_out_te = + makeTransformerEngineTensor(expert_out.data_ptr(), Shape{recv_pr, H}, eo_dtype); + auto result_te = makeTransformerEngineTensor(result.data_ptr(), Shape{T_flat, H}, eo_dtype); + + NVTECommWindow expert_out_win = maybe_make_window(expert_out); + nvte_ep_combine(handle_mem_te.data(), expert_out_te.data(), expert_out_win, result_te.data(), + stream); +} + +void ep_dispatch_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor g_recv_topk_weights, + at::Tensor grad_tokens, at::Tensor grad_topk_weights) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(grad.dim() >= 2, "grad must be at least 2D [..., recv_pr, H]"); + NVTE_CHECK(grad_tokens.dim() >= 2, "grad_tokens must be at least 2D [..., H]"); + NVTE_CHECK(grad_topk_weights.dim() >= 2, "grad_topk_weights must be at least 2D [..., top_k]"); + NVTE_CHECK(grad.is_contiguous(), "grad must be contiguous"); + NVTE_CHECK(g_recv_topk_weights.is_contiguous(), "g_recv_topk_weights must be contiguous"); + + const size_t H = static_cast(grad.size(-1)); + const size_t recv_pr = grad.numel() / H; + const size_t T_flat = grad_tokens.numel() / H; + const size_t topk_n = static_cast(grad_topk_weights.size(-1)); + NVTE_CHECK(static_cast(g_recv_topk_weights.numel()) == recv_pr, + "g_recv_topk_weights total size must equal grad recv_pr"); + NVTE_CHECK(static_cast(grad_tokens.size(-1)) == H, + "grad_tokens hidden dim must equal grad H"); + NVTE_CHECK(static_cast(grad_topk_weights.numel()) == T_flat * topk_n, + "grad_topk_weights numel (", grad_topk_weights.numel(), + ") must equal T_flat * top_k (", T_flat * topk_n, ")"); + NVTE_CHECK(grad_tokens.scalar_type() == grad.scalar_type(), "grad_tokens dtype (", + c10::toString(grad_tokens.scalar_type()), ") must match grad dtype (", + c10::toString(grad.scalar_type()), ")"); + check_symm_mem_required(grad, "grad (dispatch_bwd input)"); + check_symm_mem_required(g_recv_topk_weights, "g_recv_topk_weights"); + + auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto grad_te = makeTransformerEngineTensor(grad.data_ptr(), Shape{recv_pr, H}, g_dtype); + auto g_recv_w_te = + makeTransformerEngineTensor(g_recv_topk_weights.data_ptr(), Shape{recv_pr}, DType::kFloat32); + auto grad_tokens_te = + makeTransformerEngineTensor(grad_tokens.data_ptr(), Shape{T_flat, H}, g_dtype); + auto grad_topk_w_te = makeTransformerEngineTensor(grad_topk_weights.data_ptr(), + Shape{T_flat, topk_n}, DType::kFloat32); + + NVTECommWindow grad_win = maybe_make_window(grad); + NVTECommWindow g_recv_w_win = maybe_make_window(g_recv_topk_weights); + nvte_ep_dispatch_bwd(handle_mem_te.data(), grad_te.data(), grad_win, g_recv_w_te.data(), + g_recv_w_win, grad_tokens_te.data(), grad_topk_w_te.data(), stream); +} + +void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expert_out) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(grad.dim() >= 2, "grad must be at least 2D [..., H]"); + NVTE_CHECK(grad_expert_out.dim() >= 2, "grad_expert_out must be at least 2D [..., recv_pr, H]"); + NVTE_CHECK(grad.is_contiguous(), "grad must be contiguous"); + NVTE_CHECK(grad_expert_out.is_contiguous(), "grad_expert_out must be contiguous"); + + const size_t H = static_cast(grad.size(-1)); + const size_t T_flat = grad.numel() / H; + const size_t recv_pr = grad_expert_out.numel() / H; + NVTE_CHECK(static_cast(grad_expert_out.size(-1)) == H, + "grad_expert_out hidden dim must match grad H"); + NVTE_CHECK(grad_expert_out.scalar_type() == grad.scalar_type(), "grad_expert_out dtype (", + c10::toString(grad_expert_out.scalar_type()), ") must match grad dtype (", + c10::toString(grad.scalar_type()), ")"); + check_symm_mem_required(grad, "grad (combine_bwd input)"); + check_symm_mem_required(grad_expert_out, "grad_expert_out"); + + auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto grad_te = makeTransformerEngineTensor(grad.data_ptr(), Shape{T_flat, H}, g_dtype); + auto grad_eo_te = + makeTransformerEngineTensor(grad_expert_out.data_ptr(), Shape{recv_pr, H}, g_dtype); + + NVTECommWindow grad_win = maybe_make_window(grad); + NVTECommWindow grad_eo_win = maybe_make_window(grad_expert_out); + nvte_ep_combine_bwd(handle_mem_te.data(), grad_te.data(), grad_win, grad_eo_te.data(), + grad_eo_win, stream); +} + +void register_ep_bindings(pybind11::module_& m) { + namespace py = pybind11; + m.def("ep_initialize", &ep_initialize, + "Initialize the EP backend; borrows torch's NCCL comm pointed to by ``comm_ptr``.", + py::arg("comm_ptr"), py::arg("group_name"), py::arg("num_experts"), + py::arg("max_tokens_per_rank"), py::arg("max_recv_tokens_per_rank"), py::arg("hidden_dim"), + py::arg("max_num_sms") = 0, py::arg("max_token_dtype"), py::arg("zero_copy") = false, + py::call_guard()); + m.def("ep_finalize", &ep_finalize, "Tear down the EP backend. Idempotent.", + py::call_guard()); + m.def("ep_get_zero_copy", &ep_get_zero_copy, "Return the current EP zero-copy toggle state."); + m.def("ep_handle_mem_size", &ep_handle_mem_size, + "Return the handle_mem byte size for the given layer config.", py::arg("top_k"), + py::arg("dispatch_output_per_expert_alignment") = 0); + m.def("ep_prepare", &ep_prepare, "EP prepare", py::call_guard()); + m.def("ep_dispatch", &ep_dispatch, "EP dispatch", py::call_guard()); + m.def("ep_combine", &ep_combine, "EP combine", py::call_guard()); + m.def("ep_dispatch_bwd", &ep_dispatch_bwd, "EP dispatch backward", + py::call_guard()); + m.def("ep_combine_bwd", &ep_combine_bwd, "EP combine backward", + py::call_guard()); +} + +} // namespace transformer_engine::pytorch + +#endif // NVTE_WITH_NCCL_EP diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d6089b1e01..e55b6defa0 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -292,6 +292,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "DSquaredReLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); +#ifdef NVTE_WITH_NCCL_EP + transformer_engine::pytorch::register_ep_bindings(m); +#endif // NVTE_WITH_NCCL_EP + // Permutation functions m.def("moe_permute_fwd", transformer_engine::pytorch::moe_permute_fwd, "MOE permute FWD", py::call_guard()); diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py new file mode 100644 index 0000000000..0caf3b8c8b --- /dev/null +++ b/transformer_engine/pytorch/ep.py @@ -0,0 +1,621 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""PyTorch Expert Parallelism (EP) API.""" + +from __future__ import annotations + +import atexit +import warnings +from typing import Optional + +import torch +import torch.distributed as dist + +import transformer_engine_torch as tex + + +__all__ = [ + "EpBuffer", + "ep_bootstrap", + "ep_finalize", + "ep_dispatch", + "ep_combine", + "symm_mem_alloc", +] + + +# Symmetric-memory buffer allocator +# +# Used for the symm-mem zero-copy IO path. Set ``ep_bootstrap(zero_copy=True)`` +# to opt in; the C++ backend then operates the EP group in zero-copy mode. + + +def symm_mem_alloc( + shape, + dtype: torch.dtype, + ep_group: dist.ProcessGroup, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """Allocate and rendezvous a symm-mem buffer on ep_group. Collective on ep_group.""" + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + try: + from torch.distributed import _symmetric_memory as _symm_mem + except ImportError as e: + raise RuntimeError( + "torch.distributed._symmetric_memory is unavailable; symm_mem_alloc " + "requires PyTorch built with NCCL symm-mem support." + ) from e + if _symm_mem.get_backend(device) != "NCCL": + _symm_mem.set_backend("NCCL") + t = _symm_mem.empty(*shape, dtype=dtype, device=device) + _symm_mem.rendezvous(t, group=ep_group.group_name) + return t + + +# Bootstrap + + +# NCCL EP requires NCCL >= 2.30.4 (matches the C++ backend's runtime check). +_MIN_NCCL_VERSION = (2, 30, 4) + + +def _check_nccl_runtime_version() -> None: + """Raise with a clear message if the loaded libnccl is too old for NCCL EP.""" + import ctypes + + try: + lib = ctypes.CDLL("libnccl.so.2", mode=ctypes.RTLD_GLOBAL) + v = ctypes.c_int(0) + if lib.ncclGetVersion(ctypes.byref(v)) != 0: + import warnings + + warnings.warn("ncclGetVersion failed; skipping NCCL EP version check.") + return + except OSError: # libnccl not findable; let the C++ side error + return + n = v.value + # NCCL packs as (major*10000 + minor*100 + patch) up to ~2.x; newer + # builds use the same scheme. Decode defensively. + major, minor, patch = n // 10000, (n // 100) % 100, n % 100 + if (major, minor, patch) < _MIN_NCCL_VERSION: + min_str = ".".join(str(x) for x in _MIN_NCCL_VERSION) + raise RuntimeError( + f"NCCL EP requires NCCL >= {min_str}, found {major}.{minor}.{patch} at runtime. " + "Set LD_LIBRARY_PATH to a newer libnccl.so before launching." + ) + + +_BOOTSTRAPPED = False +_ATEXIT_REGISTERED = False + + +def _atexit_finalize() -> None: + """Best-effort teardown at interpreter shutdown; swallows errors.""" + global _BOOTSTRAPPED + if _BOOTSTRAPPED: + try: + tex.ep_finalize() + except Exception: # pylint: disable=broad-exception-caught + import traceback + + traceback.print_exc() + finally: + _BOOTSTRAPPED = False + + +def ep_bootstrap( + ep_group: dist.ProcessGroup, + num_experts: int, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + max_num_sms: int = 0, + zero_copy: bool = False, + max_token_dtype: torch.dtype = torch.bfloat16, +) -> None: + """Initialize EP by borrowing ep_group's NCCL comm. Call once per process. + + max_token_dtype sets the widest token dtype this EP group will dispatch; + it sizes NCCL EP staging buffers. + + ``zero_copy`` opts the EP group into the symm-mem zero-copy IO path; pass + ``True`` only when payload tensors are allocated via ``symm_mem_alloc``. + Defaults to ``False``. + """ + global _BOOTSTRAPPED, _ATEXIT_REGISTERED + if _BOOTSTRAPPED: + raise RuntimeError("ep_bootstrap was already called in this process") + if ep_group.size() < 2: + raise ValueError(f"ep_bootstrap requires ep_group.size() >= 2 (got {ep_group.size()}).") + _check_nccl_runtime_version() + if zero_copy: + warnings.warn( + "ep_bootstrap(zero_copy=True) is experimental; the symm-mem IO path " + "and its alias contracts on EpBuffer slots are subject to change.", + stacklevel=2, + ) + + # Materialize the PG's NCCL comm before borrowing its raw handle. + dist.barrier(group=ep_group, device_ids=[torch.cuda.current_device()]) + comm_ptr = ep_group._get_backend(torch.device("cuda"))._comm_ptr() + + tex.ep_initialize( + int(comm_ptr), + str(ep_group.group_name), + int(num_experts), + int(max_tokens_per_rank), + int(recv_capacity_per_rank), + int(hidden_dim), + int(max_num_sms), + max_token_dtype, + bool(zero_copy), + ) + _BOOTSTRAPPED = True + if not _ATEXIT_REGISTERED: + atexit.register(_atexit_finalize) + _ATEXIT_REGISTERED = True + + +def ep_finalize() -> None: + """Optional explicit EP teardown; idempotent. + + An atexit handler covers normal interpreter shutdown, so most users do not + need to call this. Call it explicitly only before + ``dist.destroy_process_group()``, since the borrowed NCCL comm becomes + invalid once the PG is destroyed. + """ + global _BOOTSTRAPPED + if not _BOOTSTRAPPED: + return + try: + tex.ep_finalize() + finally: + _BOOTSTRAPPED = False + + +# Buffer + + +class EpBuffer: + """Per-microbatch EP layer state holding handle_mem and token_counts. + Cross-rank payload buffers are caller-supplied to ep_dispatch and + ep_combine; allocate via symm_mem_alloc in zero-copy mode. + Use one EpBuffer per concurrently-in-flight call (e.g. per PP-1F1B microbatch). + """ + + __slots__ = ( + "handle_mem", + "top_k", + "alignment", + "max_tokens_per_rank", + "recv_capacity_per_rank", + "hidden_dim", + "num_local_experts", + "payload_dtype", + "device", + "token_counts", + "zero_copy", + ) + + def __init__( + self, + top_k: int, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + num_local_experts: int, + alignment: int = 0, + payload_dtype: torch.dtype = torch.bfloat16, + device: Optional[torch.device] = None, + ) -> None: + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + alignment = int(alignment) + if alignment > 1 and (alignment & (alignment - 1)) != 0: + raise ValueError(f"alignment must be 0, 1, or a power of two (got {alignment}).") + self.top_k = int(top_k) + self.alignment = alignment + self.max_tokens_per_rank = int(max_tokens_per_rank) + self.recv_capacity_per_rank = int(recv_capacity_per_rank) + self.hidden_dim = int(hidden_dim) + self.num_local_experts = int(num_local_experts) + self.payload_dtype = payload_dtype + self.device = device + self.zero_copy = bool(tex.ep_get_zero_copy()) + + size_bytes = tex.ep_handle_mem_size(self.top_k, self.alignment) + self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) + self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) + + @classmethod + def from_external( + cls, + top_k: int, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + num_local_experts: int, + *, + token_counts: Optional[torch.Tensor] = None, + alignment: int = 0, + payload_dtype: torch.dtype = torch.bfloat16, + device: Optional[torch.device] = None, + ) -> "EpBuffer": + """Construct from a caller-allocated token_counts; handle_mem is always fresh.""" + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + alignment = int(alignment) + if alignment > 1 and (alignment & (alignment - 1)) != 0: + raise ValueError(f"alignment must be 0, 1, or a power of two (got {alignment}).") + counts_shape = (num_local_experts,) + + inst = cls.__new__(cls) + inst.top_k = int(top_k) + inst.alignment = alignment + inst.max_tokens_per_rank = int(max_tokens_per_rank) + inst.recv_capacity_per_rank = int(recv_capacity_per_rank) + inst.hidden_dim = int(hidden_dim) + inst.num_local_experts = int(num_local_experts) + inst.payload_dtype = payload_dtype + inst.device = device + inst.zero_copy = bool(tex.ep_get_zero_copy()) + + size_bytes = tex.ep_handle_mem_size(inst.top_k, inst.alignment) + inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) + + if token_counts is not None: + if tuple(token_counts.shape) != counts_shape: + raise ValueError( + f"token_counts shape {tuple(token_counts.shape)} != expected {counts_shape}" + ) + if token_counts.dtype != torch.int32: + raise ValueError(f"token_counts dtype {token_counts.dtype} != expected int32") + inst.token_counts = token_counts + else: + inst.token_counts = torch.empty(counts_shape, dtype=torch.int32, device=device) + return inst + + def record_stream(self, stream: torch.cuda.Stream) -> None: + """Defer caching-allocator reclaim of owned tensors until stream catches up.""" + self.handle_mem.record_stream(stream) + self.token_counts.record_stream(stream) + + +# torch.library custom ops (so they don't graph-break under torch.compile) + +_LIB = "transformer_engine_ep" + + +@torch.library.custom_op( + f"{_LIB}::prepare", + mutates_args=("handle_mem", "token_counts"), + device_types="cuda", +) +def _prepare_op( + handle_mem: torch.Tensor, + top_k: int, + topk_idx: torch.Tensor, + token_counts: torch.Tensor, + alignment: int, +) -> None: + tex.ep_prepare(handle_mem, topk_idx, token_counts, top_k, alignment) + + +@_prepare_op.register_fake +def _(*_args, **_kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::dispatch", + mutates_args=("recv_tokens", "recv_topk_weights"), + device_types="cuda", +) +def _dispatch_op( + handle_mem: torch.Tensor, + topk_idx: torch.Tensor, + tokens: torch.Tensor, + topk_weights: torch.Tensor, + recv_tokens: torch.Tensor, + recv_topk_weights: torch.Tensor, +) -> None: + tex.ep_dispatch(handle_mem, topk_idx, tokens, topk_weights, recv_tokens, recv_topk_weights) + + +@_dispatch_op.register_fake +def _(*_args, **_kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::combine", + mutates_args=("result",), + device_types="cuda", +) +def _combine_op( + handle_mem: torch.Tensor, + expert_out: torch.Tensor, + result: torch.Tensor, +) -> None: + tex.ep_combine(handle_mem, expert_out, result) + + +@_combine_op.register_fake +def _(*_args, **_kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::dispatch_bwd", + mutates_args=("grad_tokens", "grad_topk_weights"), + device_types="cuda", +) +def _dispatch_bwd_op( + handle_mem: torch.Tensor, + grad: torch.Tensor, + g_recv_topk_weights: torch.Tensor, + grad_tokens: torch.Tensor, + grad_topk_weights: torch.Tensor, +) -> None: + tex.ep_dispatch_bwd(handle_mem, grad, g_recv_topk_weights, grad_tokens, grad_topk_weights) + + +@_dispatch_bwd_op.register_fake +def _(*_args, **_kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::combine_bwd", + mutates_args=("grad_expert_out",), + device_types="cuda", +) +def _combine_bwd_op( + handle_mem: torch.Tensor, + grad: torch.Tensor, + grad_expert_out: torch.Tensor, +) -> None: + tex.ep_combine_bwd(handle_mem, grad, grad_expert_out) + + +@_combine_bwd_op.register_fake +def _(*_args, **_kw): + return None + + +# Non-autograd primitives + + +def ep_prepare(buffer: "EpBuffer", topk_idx: torch.Tensor) -> torch.Tensor: + """AllGather the routing map; fills ``buffer.handle_mem`` and returns + ``buffer.token_counts`` (int32, shape [num_local_experts]). topk_idx must + be int64. + """ + torch.ops.transformer_engine_ep.prepare( + buffer.handle_mem, buffer.top_k, topk_idx, buffer.token_counts, buffer.alignment + ) + return buffer.token_counts + + +def _ep_dispatch_raw( + buffer: "EpBuffer", + topk_idx: torch.Tensor, + tokens: torch.Tensor, + topk_weights: torch.Tensor, + recv_tokens: torch.Tensor, + recv_topk_weights: torch.Tensor, +) -> None: + """Raw dispatch; no autograd, no prepare. Caller must run ep_prepare first.""" + tex.ep_dispatch( + buffer.handle_mem, topk_idx, tokens, topk_weights, recv_tokens, recv_topk_weights + ) + + +def _ep_combine_raw(buffer: "EpBuffer", expert_out: torch.Tensor, result: torch.Tensor) -> None: + """Raw combine; no autograd. Caller pre-weights expert_out.""" + tex.ep_combine(buffer.handle_mem, expert_out, result) + + +# autograd.Function wrappers + + +class _EpDispatch(torch.autograd.Function): + """Autograd prepare+dispatch; bwd uses user-supplied grad inputs as-is.""" + + @staticmethod + def forward( # type: ignore[override] + ctx, + handle_mem: torch.Tensor, + top_k: int, + alignment: int, + recv_tokens: torch.Tensor, + recv_topk_weights: torch.Tensor, + token_counts: torch.Tensor, + topk_idx: torch.Tensor, + tokens: torch.Tensor, + topk_weights: torch.Tensor, + ): + """Prepare + dispatch fwd.""" + torch.ops.transformer_engine_ep.prepare( + handle_mem, top_k, topk_idx, token_counts, alignment + ) + torch.ops.transformer_engine_ep.dispatch( + handle_mem, + topk_idx, + tokens, + topk_weights, + recv_tokens, + recv_topk_weights, + ) + ctx.handle_mem = handle_mem + ctx.tokens_shape = tokens.shape + ctx.tokens_dtype = tokens.dtype + ctx.topk_weights_shape = topk_weights.shape + ctx.tokens_T_flat = tokens.numel() // tokens.shape[-1] + ctx.topk_T_flat = topk_weights.numel() // topk_weights.shape[-1] + ctx.top_k = topk_weights.shape[-1] + ctx.recv_capacity = recv_tokens.shape[0] + ctx.hidden_dim = tokens.shape[-1] + ctx.mark_non_differentiable(token_counts) + # Detach so the long-lived buffers aren't tracked as differentiable outputs; + # autograd re-attaches grad_fn pointing back at this Function. + return recv_tokens.detach(), recv_topk_weights.detach(), token_counts + + @staticmethod + def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] + """Dispatch bwd; normalizes grad-input layout, otherwise passes through.""" + device = ctx.handle_mem.device + g_recv_tokens = g_recv_tokens.contiguous() + g_recv_topk_weights = g_recv_topk_weights.contiguous() + grad_tokens = torch.empty( + ctx.tokens_T_flat, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device + ) + grad_topk_weights = torch.empty( + ctx.topk_T_flat, ctx.top_k, dtype=torch.float32, device=device + ) + torch.ops.transformer_engine_ep.dispatch_bwd( + ctx.handle_mem, + g_recv_tokens, + g_recv_topk_weights, + grad_tokens, + grad_topk_weights, + ) + return ( + None, # handle_mem + None, # top_k + None, # alignment + None, # recv_tokens + None, # recv_topk_weights + None, # token_counts + None, # topk_idx + grad_tokens.view(ctx.tokens_shape), + grad_topk_weights.view(ctx.topk_weights_shape), + ) + + +class _EpCombine(torch.autograd.Function): + """Autograd combine; bwd writes into grad_expert_out, or expert_out's storage if None.""" + + @staticmethod + def forward( # type: ignore[override] + ctx, + handle_mem: torch.Tensor, + num_local_tokens: int, + hidden_dim: int, + grad_expert_out: Optional[torch.Tensor], + expert_out: torch.Tensor, + ): + """Combine fwd; stashes grad_expert_out (or expert_out) as the bwd output slot.""" + device = expert_out.device + result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) + torch.ops.transformer_engine_ep.combine(handle_mem, expert_out, result) + ctx.handle_mem = handle_mem + ctx.grad_expert_out = grad_expert_out if grad_expert_out is not None else expert_out + return result + + @staticmethod + def backward(ctx, g_result): # type: ignore[override] + if not g_result.is_contiguous(): + g_result = g_result.contiguous() + grad_combine_in = ctx.grad_expert_out + torch.ops.transformer_engine_ep.combine_bwd(ctx.handle_mem, g_result, grad_combine_in) + return ( + None, # handle_mem + None, # num_local_tokens + None, # hidden_dim + None, # grad_expert_out + grad_combine_in, + ) + + +# Public high-level wrappers + + +# NCCL EP currently only supports bfloat16 payload tensors. +def _require_bf16(name: str, t: torch.Tensor) -> None: + if t.dtype is not torch.bfloat16: + raise NotImplementedError( + f"NCCL EP currently supports only bfloat16 payloads; got {name}.dtype={t.dtype}." + ) + + +def ep_dispatch( + buffer: EpBuffer, + tokens: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + *, + recv_tokens: Optional[torch.Tensor] = None, + recv_topk_weights: Optional[torch.Tensor] = None, +): + """Prepare + dispatch with autograd. topk_idx must be int64. + + recv_tokens / recv_topk_weights are used as-is if supplied, else allocated. + Zero-copy mode requires both to be supplied and symm-mem-backed. + Returns (recv_tokens, recv_topk_weights, token_counts); token_counts is non-diff. + """ + _require_bf16("tokens", tokens) + if topk_weights.dtype is not torch.float32: + raise TypeError( + f"topk_weights must be float32; got dtype={topk_weights.dtype}. " + "Cast with topk_weights.float() before calling." + ) + if buffer.zero_copy and (recv_tokens is None or recv_topk_weights is None): + raise ValueError( + "ep_dispatch: zero-copy mode requires caller-supplied recv_tokens and " + "recv_topk_weights (allocate via symm_mem_alloc)." + ) + if recv_tokens is None: + recv_tokens = torch.empty( + buffer.recv_capacity_per_rank, + buffer.hidden_dim, + dtype=buffer.payload_dtype, + device=buffer.device, + ) + if recv_topk_weights is None: + recv_topk_weights = torch.empty( + buffer.recv_capacity_per_rank, dtype=torch.float32, device=buffer.device + ) + return _EpDispatch.apply( + buffer.handle_mem, + buffer.top_k, + buffer.alignment, + recv_tokens, + recv_topk_weights, + buffer.token_counts, + topk_idx, + tokens, + topk_weights, + ) + + +def ep_combine( + buffer: EpBuffer, + expert_out: torch.Tensor, + *, + num_local_tokens: Optional[int] = None, + grad_expert_out: Optional[torch.Tensor] = None, +): + """Combine with autograd; caller pre-applies topk weighting. + + grad_expert_out is the slot the bwd writes into; if None, expert_out's storage is reused. + Zero-copy mode requires both expert_out and grad_expert_out to be symm-mem-backed. + Result shape is (num_local_tokens, buffer.hidden_dim); defaults to buffer.max_tokens_per_rank rows. + """ + _require_bf16("expert_out", expert_out) + if buffer.zero_copy and grad_expert_out is None: + raise ValueError( + "ep_combine: zero-copy mode requires caller-supplied grad_expert_out " + "(allocate via symm_mem_alloc)." + ) + if num_local_tokens is None: + num_local_tokens = buffer.max_tokens_per_rank + return _EpCombine.apply( + buffer.handle_mem, + num_local_tokens, + buffer.hidden_dim, + grad_expert_out, + expert_out, + )