From 8319f82e0225333a71d4aa4962cc233f04abb4b9 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 14:59:17 -0700 Subject: [PATCH 01/18] EP PyTorch: NCCL EP backend + autograd ops + tests, route zero_copy via cfg.zero_copy Signed-off-by: Phuong Nguyen --- build_tools/pytorch.py | 10 + examples/pytorch/ep/bench/ep_bench.py | 398 ++++++++++ examples/pytorch/ep/bench/run_ep_bench.sh | 72 ++ .../pytorch/ep/bench/run_nccl_ep_bench.sh | 62 ++ examples/pytorch/ep/ep_moe.py | 228 ++++++ examples/pytorch/ep/run_test_ep.sh | 37 + tests/pytorch/distributed/run_ep.py | 338 ++++++++ tests/pytorch/distributed/run_test_ep.sh | 55 ++ tests/pytorch/distributed/test_ep.py | 31 + transformer_engine/pytorch/csrc/extensions.h | 37 + .../pytorch/csrc/extensions/ep.cpp | 331 ++++++++ .../pytorch/csrc/extensions/pybind.cpp | 4 + transformer_engine/pytorch/ep.py | 734 ++++++++++++++++++ 13 files changed, 2337 insertions(+) create mode 100644 examples/pytorch/ep/bench/ep_bench.py create mode 100755 examples/pytorch/ep/bench/run_ep_bench.sh create mode 100755 examples/pytorch/ep/bench/run_nccl_ep_bench.sh create mode 100644 examples/pytorch/ep/ep_moe.py create mode 100755 examples/pytorch/ep/run_test_ep.sh create mode 100644 tests/pytorch/distributed/run_ep.py create mode 100755 tests/pytorch/distributed/run_test_ep.sh create mode 100644 tests/pytorch/distributed/test_ep.py create mode 100644 transformer_engine/pytorch/csrc/extensions/ep.cpp create mode 100644 transformer_engine/pytorch/ep.py diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index e2e6d09c29..ca54e72434 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -77,6 +77,16 @@ def setup_pytorch_extension( setup_mpi_flags(include_dirs, cxx_flags) + # Mirror the NCCL EP gate from setup.py / common CMake. When disabled, the + # ep.cpp source no-ops at the #ifdef boundary; without the define it would + # produce undefined references to nvte_ep_*. 
+    if bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))):
+        cxx_flags.append("-DNVTE_WITH_NCCL_EP")
+        # PyTorch's symm-mem headers gate the NCCL_HAS_SYMMEM_* feature macros on
+        # USE_NCCL. The EP extension shares the symm-mem NCCL comm with torch, so
+        # it needs those macros visible.
+        cxx_flags.append("-DUSE_NCCL")
+
     library_dirs = []
     libraries = []
     if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py
new file mode 100644
index 0000000000..86217b7f91
--- /dev/null
+++ b/examples/pytorch/ep/bench/ep_bench.py
@@ -0,0 +1,398 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""PyTorch EP perf bench: raw and autograd dispatch/combine on a single EP group.
+
+One process per GPU; launched via run_ep_bench.sh (torchrun).
+
+Stages (each timed in its own loop):
+  - dispatch_raw:        _ep_dispatch_raw (no autograd, no prepare)
+  - ep_dispatch_fwd:     ep_dispatch forward only
+  - ep_dispatch_fwd_bwd: ep_dispatch + backward on 0.5 * ||recv||^2
+  - combine_raw:         _ep_combine_raw (no autograd)
+  - ep_combine_fwd:      ep_combine forward only
+  - ep_combine_fwd_bwd:  ep_combine + backward
+
+ep_prepare runs once outside the timed loops. --kineto DIR dumps a Chrome
+trace plus a per-kernel summary on rank 0.
+"""
+
+import argparse
+import gc
+import os
+import sys
+import time
+from contextlib import nullcontext
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from transformer_engine.pytorch.ep import (
+    EpBuffer,
+    EpHandle,
+    ep_bootstrap,
+    ep_combine,
+    ep_dispatch,
+    ep_finalize,
+    ep_prepare,
+    _ep_combine_raw,
+    _ep_dispatch_raw,
+)
+
+
+def _parse_args():
+    p = argparse.ArgumentParser(description="TE-PyTorch EP perf bench")
+    p.add_argument("--tokens-per-rank", type=int, default=8192)
+    p.add_argument("--hidden", type=int, default=7168)
+    p.add_argument("--top-k", type=int, default=8)
+    p.add_argument("--num-experts", type=int, default=256)
+    p.add_argument("--warmup", type=int, default=2)
+    p.add_argument("--iters", type=int, default=10)
+    p.add_argument(
+        "--max-num-sms",
+        type=int,
+        default=0,
+        help="Max SMs for dispatch/combine/preprocess kernels (0 = auto).",
+    )
+    p.add_argument(
+        "--kineto",
+        default=None,
+        help="If set, dump a Kineto Chrome trace + per-kernel summary into this dir (rank 0).",
+    )
+    p.add_argument(
+        "--cuda-graph",
+        action="store_true",
+        default=False,
+        help=(
+            "Capture each stage into a CUDA graph and time replay() instead of the eager call. "
+            "Raw + fwd-only stages use torch.cuda.graph; fwd+bwd stages use "
+            "torch.cuda.make_graphed_callables to capture forward and backward together."
+        ),
+    )
+    p.add_argument(
+        "--mode-label",
+        default=None,
+        help="Optional suffix for NVTX range names (e.g. 'fused' / 'unfused').",
+    )
+    return p.parse_args()
+
+
+def _nvtx_funcs():
+    """Return push/pop helpers using torch.cuda.nvtx if available."""
+    try:
+        push = torch.cuda.nvtx.range_push
+        pop = torch.cuda.nvtx.range_pop
+        return push, pop
+    except AttributeError:
+        return lambda _name: None, lambda: None
+
+
+def _device_sm() -> int:
+    major, minor = torch.cuda.get_device_capability()
+    return major * 10 + minor
+
+
+def _make_inputs(rank, world_size, T, H, K, E, device):
+    """Round-robin identity routing + uniform top-k weights."""
+    topk_idx = np.empty((T, K), dtype=np.int64)
+    for t in range(T):
+        for k in range(K):
+            topk_idx[t, k] = ((rank * T + t) * K + k) % E
+    rng = np.random.default_rng(seed=42 + rank)
+    tokens_np = (rng.standard_normal((T, H), dtype=np.float32) * 0.5).astype(np.float32)
+    return (
+        torch.from_numpy(topk_idx).to(device),
+        torch.from_numpy(tokens_np).to(device=device, dtype=torch.bfloat16),
+        torch.full((T, K), 1.0 / K, dtype=torch.float32, device=device),
+    )
+
+
+def _time_stage_us(name, fn, iters, nvtx_suffix, push, pop):
+    """Time fn over iters calls after one untimed warmup call; return the mean in us."""
+    if iters < 1:
+        raise ValueError(f"iters must be >= 1 (got {iters})")
+    # The untimed first call absorbs one-off costs and stays outside the NVTX range.
+    fn()
+    torch.cuda.synchronize()
+    total_ns = 0
+    push(f"{name}{nvtx_suffix}")
+    try:
+        for _ in range(iters):
+            torch.cuda.synchronize()
+            t0 = time.perf_counter_ns()
+            fn()
+            torch.cuda.synchronize()
+            total_ns += time.perf_counter_ns() - t0
+    finally:
+        pop()  # keep the NVTX range balanced even when fn() raises
+    return total_ns / 1e3 / iters
+
+
+def main():
+    args = _parse_args()
+    dist.init_process_group(backend="nccl")
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank)))
+    device = torch.device("cuda", torch.cuda.current_device())
+
+    if _device_sm() < 90:
+        if rank == 0:
+            print(f"[ep_bench] SKIPPED: EP requires SM>=90 (got SM{_device_sm()})")
+        dist.destroy_process_group()
+        return
+    if world_size < 4:
+        if rank == 0:
+            print(f"[ep_bench] SKIPPED: EP requires >=4 ranks (got {world_size})")
+        dist.destroy_process_group()
+        return
+
+    ep_size = world_size
+    E = args.num_experts
+    assert E % ep_size == 0, f"num_experts ({E}) must be divisible by ep_size ({ep_size})"
+    num_local_experts = E // ep_size
+    T = args.tokens_per_rank
+    H = args.hidden
+    K = args.top_k
+    # Half the absolute worst case (world_size * T * K slots landing on one rank).
+    recv_pr = world_size * T * K // 2
+    if rank == 0:
+        print(
+            f"[ep_bench] world={world_size} ep={ep_size} T={T} H={H} K={K} "
+            f"E={E} (local={num_local_experts}) recv_pr={recv_pr}"
+            + (f" mode={args.mode_label}" if args.mode_label else ""),
+            flush=True,
+        )
+
+    ep_group = dist.new_group(ranks=list(range(world_size)), backend="nccl")
+    ep_bootstrap(
+        ep_group,
+        num_experts=E,
+        max_tokens_per_rank=T,
+        recv_capacity_per_rank=recv_pr,
+        hidden_dim=H,
+        max_num_sms=args.max_num_sms,
+    )
+
+    topk_idx, tokens_hbm, topk_w_hbm = _make_inputs(rank, world_size, T, H, K, E, device)
+
+    handle = EpHandle(
+        top_k=K,
+        max_tokens_per_rank=T,
+        recv_capacity_per_rank=recv_pr,
+        hidden_dim=H,
+        num_local_experts=num_local_experts,
+    )
+    buffer = EpBuffer(handle)
+
+    tokens = tokens_hbm
+    topk_w = topk_w_hbm
+    recv_tokens = torch.empty(recv_pr, H, dtype=torch.bfloat16, device=device)
+    recv_w = torch.empty(recv_pr, dtype=torch.float32, device=device)
+
+    # -- Prepare once outside the timed loops ------------------------------
+    ep_prepare(handle, topk_idx)
+    torch.cuda.synchronize()
+
+    # Pre-dispatch a steady recv_tokens / recv_w so combine stages have valid input.
+    _ep_dispatch_raw(handle, topk_idx, tokens, topk_w, recv_tokens, recv_w)
+    torch.cuda.synchronize()
+    # fp-equivalent stand-in for an MLP output.
+    expert_out = recv_tokens.clone()
+
+    nvtx_suffix = f"[{args.mode_label}]" if args.mode_label else ""
+    push, pop = _nvtx_funcs()
+
+    # -- Stage closures ----------------------------------------------------
+    # Persistent fwd+bwd inputs (make_graphed_callables needs stable storage).
+    tokens_p = tokens.detach().clone().requires_grad_(True)
+    eo_p = recv_tokens.detach().clone().requires_grad_(True)
+
+    # Stand-in callables; the cuda-graph branch below swaps in graphed versions.
+    fwd_bwd_dispatch_fn = lambda x: ep_dispatch(handle, buffer, x, topk_idx, topk_w)[  # noqa: E731
+        0
+    ]
+    fwd_bwd_combine_fn = lambda eo: ep_combine(handle, buffer, eo)  # noqa: E731
+
+    def _dispatch_raw():
+        _ep_dispatch_raw(handle, topk_idx, tokens, topk_w, recv_tokens, recv_w)
+
+    def _combine_raw():
+        out_buf = torch.empty(T, H, dtype=torch.bfloat16, device=device)
+        _ep_combine_raw(handle, expert_out, out_buf)
+
+    def _ep_dispatch_fwd():
+        ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w)
+
+    def _ep_dispatch_fwd_bwd():
+        tokens_p.grad = None
+        r = fwd_bwd_dispatch_fn(tokens_p)
+        (0.5 * (r * r).sum(dtype=torch.float32)).backward()
+
+    def _ep_combine_fwd():
+        ep_combine(handle, buffer, recv_tokens)
+
+    def _ep_combine_fwd_bwd():
+        eo_p.grad = None
+        out = fwd_bwd_combine_fn(eo_p)
+        (0.5 * (out * out).sum(dtype=torch.float32)).backward()
+
+    stages = [
+        ("dispatch_raw", _dispatch_raw, True),
+        ("ep_dispatch_fwd", _ep_dispatch_fwd, True),
+        ("ep_dispatch_fwd_bwd", _ep_dispatch_fwd_bwd, False),
+        ("combine_raw", _combine_raw, True),
+        ("ep_combine_fwd", _ep_combine_fwd, True),
+        ("ep_combine_fwd_bwd", _ep_combine_fwd_bwd, False),
+    ]
+    # Third tuple element: True = direct torch.cuda.graph capture; False = use
+    # make_graphed_callables (autograd-aware) instead.
+
+    # -- Warmup -----------------------------------------------------------
+    for _ in range(args.warmup):
+        for _name, fn, _capt in stages:
+            fn()
+    torch.cuda.synchronize()
+
+    # -- Optional CUDA-graph capture --------------------------------------
+    # Capture each capturable stage on a side stream and time .replay()
+    # instead of the eager call. Outputs allocated inside the
+    # autograd.Function's forward go through the per-capture private pool
+    # so addresses stay stable across replays.
+    captured_runners = {}
+    if args.cuda_graph:
+        # Graph fwd+bwd of the autograd-wrapped ops via make_graphed_callables.
+        class _DispatchMod(torch.nn.Module):
+            def forward(self, x):
+                return ep_dispatch(handle, buffer, x, topk_idx, topk_w)[0]
+
+        class _CombineMod(torch.nn.Module):
+            def forward(self, eo):
+                return ep_combine(handle, buffer, eo)
+
+        disp_mod = _DispatchMod().cuda()
+        comb_mod = _CombineMod().cuda()
+        g_disp, g_comb = torch.cuda.make_graphed_callables(
+            (disp_mod, comb_mod),
+            ((tokens_p,), (eo_p,)),
+        )
+        fwd_bwd_dispatch_fn = g_disp
+        fwd_bwd_combine_fn = g_comb
+
+        # Direct torch.cuda.graph capture for raw + fwd-only stages.
+        side = torch.cuda.Stream()
+        side.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(side):
+            for name, fn, direct_capturable in stages:
+                if not direct_capturable:
+                    continue
+                fn()  # prime the allocator for stable replay addresses
+                torch.cuda.synchronize()
+                g = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(g):
+                    fn()
+                captured_runners[name] = g
+        torch.cuda.current_stream().wait_stream(side)
+        torch.cuda.synchronize()
+
+    # -- Optional Kineto profiling ----------------------------------------
+    kineto_ctx = nullcontext()
+    if args.kineto and rank == 0:
+        os.makedirs(args.kineto, exist_ok=True)
+        kineto_ctx = torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            record_shapes=False,
+            with_stack=False,
+        )
+
+    # -- Timed loops ------------------------------------------------------
+    results = {}
+    with kineto_ctx as prof:
+        for name, fn, _ in stages:
+            runner = fn
+            if name in captured_runners:
+                # Time replay() instead of the eager call.
+                graph = captured_runners[name]
+                runner = graph.replay
+            results[name] = _time_stage_us(name, runner, args.iters, nvtx_suffix, push, pop)
+
+    if rank == 0:
+        label = f" [{args.mode_label}]" if args.mode_label else ""
+        print("", flush=True)
+        print(f"| stage | mean wall (us){label} |", flush=True)
+        print("|----------------------|---------------:|", flush=True)
+        for name in (
+            "dispatch_raw",
+            "ep_dispatch_fwd",
+            "ep_dispatch_fwd_bwd",
+            "combine_raw",
+            "ep_combine_fwd",
+            "ep_combine_fwd_bwd",
+        ):
+            print(f"| {name:20s} | {results[name]:14.1f} |", flush=True)
+        print(
+            "| (dispatch fwd-raw) |"
+            f" {results['ep_dispatch_fwd'] - results['dispatch_raw']:14.1f} |",
+            flush=True,
+        )
+        print(
+            "| (dispatch bwd-fwd) |"
+            f" {results['ep_dispatch_fwd_bwd'] - results['ep_dispatch_fwd']:14.1f} |",
+            flush=True,
+        )
+        print(
+            "| (combine fwd-raw) |"
+            f" {results['ep_combine_fwd'] - results['combine_raw']:14.1f} |",
+            flush=True,
+        )
+        print(
+            "| (combine bwd-fwd) |"
+            f" {results['ep_combine_fwd_bwd'] - results['ep_combine_fwd']:14.1f} |",
+            flush=True,
+        )
+        print("", flush=True)
+
+    if args.kineto and rank == 0 and prof is not None:
+        trace_path = os.path.join(args.kineto, "ep_bench_trace.json")
+        prof.export_chrome_trace(trace_path)
+        print(f"[ep_bench] kineto trace: {trace_path}", flush=True)
+        print(
+            prof.key_averages().table(sort_by="cuda_time_total", row_limit=30),
+            flush=True,
+        )
+        kern_csv = os.path.join(args.kineto, "ep_bench_kernels.csv")
+        with open(kern_csv, "w") as f:
+            f.write("name,cuda_time_us,cpu_time_us,count\n")
+            for evt in prof.key_averages():
+                if evt.device_time_total == 0 and evt.cpu_time_total == 0:
+                    continue
+                f.write(f"{evt.key},{evt.device_time_total},{evt.cpu_time_total},{evt.count}\n")
+        print(f"[ep_bench] per-kernel CSV: {kern_csv}", flush=True)
+
+    # Captured CUDA graphs (when --cuda-graph) hold references to NCCL EP
+    # handles and per-pool streams; drop them and sync before ep_finalize,
+    # otherwise the post-finalize dist.barrier can deadlock against pending
+    # graph state.
+    torch.cuda.synchronize()
+    if args.cuda_graph:
+        fwd_bwd_dispatch_fn = None
+        fwd_bwd_combine_fn = None
+        captured_runners.clear()
+        del g_disp, g_comb, disp_mod, comb_mod
+    del tokens_p, eo_p, buffer, handle, recv_tokens, recv_w, tokens, topk_w, expert_out
+    gc.collect()
+    torch.cuda.synchronize()
+    # Release NCCL EP's borrowed comm before torch destroys it.
+    ep_finalize()
+    dist.barrier()
+    dist.destroy_process_group()
+    sys.stdout.flush()
+    sys.stderr.flush()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/pytorch/ep/bench/run_ep_bench.sh b/examples/pytorch/ep/bench/run_ep_bench.sh
new file mode 100755
index 0000000000..fefecd7fa9
--- /dev/null
+++ b/examples/pytorch/ep/bench/run_ep_bench.sh
@@ -0,0 +1,72 @@
+#!/usr/bin/env bash
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+#
+# Launcher for examples/pytorch/ep/bench/ep_bench.py.
+# Examples:
+#   bash run_ep_bench.sh               # plain run, stdout only
+#   bash run_ep_bench.sh --cuda-graph  # capture + replay each stage as a CUDA graph
+#   bash run_ep_bench.sh --kineto      # Chrome trace + per-kernel CSV (rank 0)
+#   bash run_ep_bench.sh --nsys        # nsys profile on rank 0 -> results/pyt_nsys.nsys-rep
+
+set -uo pipefail
+
+NSYS=0; KINETO=0; CGRAPH=0
+for a in "$@"; do
+  case "$a" in
+    --nsys) NSYS=1 ;;
+    --kineto) KINETO=1 ;;
+    --cuda-graph) CGRAPH=1 ;;
+    *) echo "unknown arg: $a" >&2; exit 2 ;;
+  esac
+done
+if [ "${NSYS}" -eq 1 ] && [ "${KINETO}" -eq 1 ]; then
+  echo "--nsys and --kineto both attach CUPTI; pick one." >&2; exit 2
+fi
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../../.." && pwd)"
+RESULTS="${SCRIPT_DIR}/results"
+mkdir -p "${RESULTS}"
+export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}"
+
+DETECTED_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l)
+NUM_GPUS="${NUM_GPUS:-${DETECTED_GPUS}}"
+if [ "${NUM_GPUS}" -lt 4 ]; then
+  echo "EP bench requires >=4 GPUs (found ${NUM_GPUS}); SKIPPING."; exit 0
+fi
+if [ "${NUM_GPUS}" -gt 8 ]; then NUM_GPUS=8; fi
+
+: "${TIMEOUT_S:=1800}"
+: "${NCCL_EP_JIT_CACHE_DIR:=${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)}"
+export NCCL_EP_JIT_CACHE_DIR
+mkdir -p "${NCCL_EP_JIT_CACHE_DIR}"
+
+EXTRA_ARGS=()
+TAG="pyt"
+[ "${CGRAPH}" -eq 1 ] && EXTRA_ARGS+=(--cuda-graph) && TAG="${TAG}_cg"
+if [ "${KINETO}" -eq 1 ]; then
+  EXTRA_ARGS+=(--kineto "${RESULTS}/kineto_${TAG}")
+fi
+
+EP_BENCH_EXTRA_FLAGS="${EP_BENCH_EXTRA_FLAGS:-}"
+LAUNCH=(torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_GPUS}"
+        "${SCRIPT_DIR}/ep_bench.py" "${EXTRA_ARGS[@]}" ${EP_BENCH_EXTRA_FLAGS})
+
+if [ "${NSYS}" -eq 1 ]; then
+  NSYS_CMD=(nsys profile
+    --output "${RESULTS}/${TAG}_nsys"
+    --force-overwrite=true
+    --trace=cuda,nvtx
+    --gpu-metrics-devices=none
+    --cuda-um-cpu-page-faults=false
+    --cuda-um-gpu-page-faults=false)
+  echo "[run_ep_bench] launching with nsys (results/${TAG}_nsys.nsys-rep)"
+  timeout --foreground --signal=TERM "${TIMEOUT_S}" "${NSYS_CMD[@]}" "${LAUNCH[@]}"
+  RC=$?
+else
+  timeout --foreground --signal=TERM "${TIMEOUT_S}" "${LAUNCH[@]}"
+  RC=$?
+fi
+exit $RC
diff --git a/examples/pytorch/ep/bench/run_nccl_ep_bench.sh b/examples/pytorch/ep/bench/run_nccl_ep_bench.sh
new file mode 100755
index 0000000000..8f6da04a00
--- /dev/null
+++ b/examples/pytorch/ep/bench/run_nccl_ep_bench.sh
@@ -0,0 +1,62 @@
+#!/usr/bin/env bash
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+#
+# Launcher for the native NCCL EP ``ep_bench`` (baseline for PyTorch comparison).
+# Usage:
+#   bash run_nccl_ep_bench.sh          # plain run, stdout only
+#   bash run_nccl_ep_bench.sh --nsys   # nsys → results/nccl_ep_nsys.nsys-rep
+
+set -uo pipefail
+
+NSYS=0
+for a in "$@"; do
+  case "$a" in
+    --nsys) NSYS=1 ;;
+    *) echo "unknown arg: $a" >&2; exit 2 ;;
+  esac
+done
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../../.." && pwd)"
+RESULTS="${SCRIPT_DIR}/results"
+mkdir -p "${RESULTS}"
+
+BIN="${TE_REPO_ROOT}/3rdparty/nccl/build/test/nccl_ep/ep_bench"
+LIB="${TE_REPO_ROOT}/3rdparty/nccl/build/lib"
+[ -x "${BIN}" ] || { echo "ep_bench not built at ${BIN}" >&2; exit 2; }
+
+NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l)
+if [ "${NUM_GPUS}" -lt 4 ]; then
+  echo "NCCL EP bench requires >=4 GPUs (found ${NUM_GPUS}); SKIPPING."; exit 0
+fi
+if [ "${NUM_GPUS}" -gt 8 ]; then NUM_GPUS=8; fi
+
+if [ "${NSYS}" -eq 1 ]; then ITERS=10; else ITERS=50; fi
+ARGS=(--algorithm ht --layout em --tokens 2048 --hidden 7168 --top-k 8
+      --experts 256 --warmup 5 --iters "${ITERS}")
+[ "${NSYS}" -eq 1 ] && ARGS+=(--profile)  # enables NVTX ranges + cudaProfilerStart/Stop
+
+CMD=(/usr/local/mpi/bin/mpirun --allow-run-as-root --oversubscribe -np "${NUM_GPUS}"
+     -x LD_LIBRARY_PATH="${LIB}:${LD_LIBRARY_PATH:-}"
+     "${BIN}" "${ARGS[@]}")
+
+if [ "${NSYS}" -eq 1 ]; then
+  CMD=(nsys profile
+    --output "${RESULTS}/nccl_ep_nsys"
+    --force-overwrite=true
+    --capture-range=cudaProfilerApi
+    --capture-range-end=stop
+    --trace=cuda,nvtx,osrt
+    "${CMD[@]}")
+fi
+
+[ "${NSYS}" -eq 1 ] && SUFFIX="_nsys" || SUFFIX=""
+LOG="${RESULTS}/stdout_nccl_ep${SUFFIX}.txt"
+# Propagate the benchmark's own exit status (first pipeline stage); without
+# this the script would always exit 0 because its last command is `echo`.
+"${CMD[@]}" 2>&1 | tee "${LOG}"
+RC=${PIPESTATUS[0]}
+echo "Done. Log: ${LOG}"
+exit "${RC}"
diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py
new file mode 100644
index 0000000000..934d88d8c7
--- /dev/null
+++ b/examples/pytorch/ep/ep_moe.py
@@ -0,0 +1,228 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""End-to-end MoE example: dispatch -> batched expert linear -> combine, fwd + bwd. + +One process per GPU; launched via run_test_ep.sh (torchrun). +""" + +import argparse +import os +import sys + +import numpy as np +import torch +import torch.distributed as dist + +from transformer_engine.pytorch.ep import ( + EpHandle, + EpBuffer, + ep_scope, + ep_dispatch, + ep_combine, +) + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-PyTorch EP MoE example (fwd + bwd)") + p.add_argument("--num-tokens", type=int, default=8, help="Per-rank token count.") + p.add_argument("--top-k", type=int, default=2) + p.add_argument("--hidden", type=int, default=32) + p.add_argument("--hidden-out", type=int, default=32) + p.add_argument("--num-experts", type=int, default=None) + p.add_argument("--check", action="store_true", default=True) + p.add_argument( + "--benchmark", + action="store_true", + help="Time fwd over HBM buffers.", + ) + p.add_argument("--benchmark-iters", type=int, default=20) + p.add_argument("--benchmark-warmup", type=int, default=5) + return p.parse_args() + + +def _make_routing(rank, T, K, E, num_local_experts): + """Deterministic routing: topk_idx[t, k] = (rank*NLE + t*K + k) % E.""" + topk_idx = np.empty((T, K), dtype=np.int64) + for t in range(T): + for k in range(K): + topk_idx[t, k] = (rank * num_local_experts + t * K + k) % E + return topk_idx + + +def _batched_expert_linear(recv_tokens, kernels, num_local_experts): + """Per-expert linear via bmm; ``recv_pr // num_local_experts`` slots per expert.""" + recv_pr, _H = recv_tokens.shape + H_out = kernels.shape[-1] + slots_per_expert = recv_pr // num_local_experts + grouped = recv_tokens.view(num_local_experts, slots_per_expert, recv_tokens.shape[-1]) + out = torch.bmm(grouped, kernels.to(grouped.dtype)) + return out.view(recv_pr, H_out) + + +def _reference_moe(tokens, topk_idx, topk_w, kernels): + T, K = topk_idx.shape + H_out = kernels.shape[-1] + out = np.zeros((T, H_out), dtype=np.float32) 
+ for t in range(T): + tok = tokens[t].astype(np.float32) + for k in range(K): + e = int(topk_idx[t, k]) + out[t] += float(topk_w[t, k]) * (tok @ kernels[e].astype(np.float32)) + return out + + +def _reference_grad(tokens, topk_idx, topk_w, kernels): + T, K = topk_idx.shape + H = tokens.shape[-1] + ref_out = _reference_moe(tokens, topk_idx, topk_w, kernels) + grad = np.zeros((T, H), dtype=np.float32) + for t in range(T): + mixed = np.zeros_like(kernels[0]) + for k in range(K): + mixed = mixed + float(topk_w[t, k]) * kernels[int(topk_idx[t, k])] + grad[t] = ref_out[t] @ mixed.T + return ref_out, grad + + +def main(): + args = _parse_args() + + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank))) + device = torch.device("cuda", torch.cuda.current_device()) + + major, minor = torch.cuda.get_device_capability() + if major * 10 + minor < 90: + if rank == 0: + print(f"[ep_moe] SKIPPED: EP requires SM>=90 (got SM{major}{minor})") + dist.destroy_process_group() + return + + if world_size < 4: + if rank == 0: + print(f"[ep_moe] SKIPPED: EP requires >= 4 ranks (got {world_size})") + dist.destroy_process_group() + return + + ep_size = world_size + num_experts = args.num_experts if args.num_experts is not None else world_size + assert num_experts % ep_size == 0 + num_local_experts = num_experts // ep_size + T = args.num_tokens + recv_pr = ep_size * T * args.top_k + + ep_group = dist.new_group(ranks=list(range(world_size)), backend="nccl") + with ep_scope( + ep_group, + num_experts=num_experts, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=args.hidden, + ): + _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device) + dist.destroy_process_group() + + +def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device): + rng = np.random.default_rng(seed=42 + rank) + 
tokens_np = (rng.standard_normal((T, args.hidden), dtype=np.float32) * 0.5).astype(np.float32) + topk_idx_np = _make_routing(rank, T, args.top_k, num_experts, num_local_experts) + w_np = np.full((T, args.top_k), 1.0 / args.top_k, dtype=np.float32) + # Same seed across ranks -> identical kernel array everywhere. + kr = np.random.default_rng(seed=42) + kernels_np = ( + kr.standard_normal((num_experts, args.hidden, args.hidden_out), dtype=np.float32) + * (1.0 / np.sqrt(args.hidden)) + ).astype(np.float32) + + tokens = ( + torch.from_numpy(tokens_np).to(device=device, dtype=torch.bfloat16).requires_grad_(True) + ) + topk_idx = torch.from_numpy(topk_idx_np).to(device) + topk_w = torch.from_numpy(w_np).to(device) + kernels_local = torch.from_numpy( + kernels_np[rank * num_local_experts : (rank + 1) * num_local_experts] + ).to(device=device, dtype=torch.bfloat16) + + handle = EpHandle( + top_k=args.top_k, + max_tokens_per_rank=T, + recv_capacity_per_rank=recv_pr, + hidden_dim=args.hidden, + num_local_experts=num_local_experts, + ) + buffer = EpBuffer(handle) + + recv_t, recv_w_out, _tc = ep_dispatch(handle, buffer, tokens, topk_idx, topk_w) + expert_out = _batched_expert_linear(recv_t, kernels_local, num_local_experts) + # Apply per-slot topk weighting before combine. + expert_out = expert_out * recv_w_out.unsqueeze(-1).to(expert_out.dtype) + out = ep_combine(handle, buffer, expert_out) + + loss = 0.5 * (out.float() ** 2).sum() + loss.backward() + torch.cuda.synchronize() + + if rank == 0: + print( + f"[ep_moe] loss={loss.item():.4f} grad_tokens.shape={tuple(tokens.grad.shape)} " + f"ep={ep_size} num_experts={num_experts} recv_pr={recv_pr}" + ) + + if args.benchmark: + # Time dispatch + expert + combine over HBM buffers. 
+ import time + + torch.cuda.synchronize() + dist.barrier() + for _ in range(args.benchmark_warmup): + rt, rw, _tc = ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w) + eo = _batched_expert_linear(rt, kernels_local, num_local_experts) + eo = eo * rw.unsqueeze(-1).to(eo.dtype) + ep_combine(handle, buffer, eo) + torch.cuda.synchronize() + dist.barrier() + t0 = time.perf_counter() + for _ in range(args.benchmark_iters): + rt, rw, _tc = ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w) + eo = _batched_expert_linear(rt, kernels_local, num_local_experts) + eo = eo * rw.unsqueeze(-1).to(eo.dtype) + ep_combine(handle, buffer, eo) + torch.cuda.synchronize() + dt_ms = (time.perf_counter() - t0) * 1000.0 / args.benchmark_iters + if rank == 0: + print( + f"[ep_moe --benchmark] HBM: {dt_ms:.3f} ms/iter " + f"(iters={args.benchmark_iters})" + ) + + if args.check: + # All-gather inputs/outputs/grads for a global reference comparison. + global_tokens = [torch.empty_like(tokens) for _ in range(world_size)] + global_topk_idx = [torch.empty_like(topk_idx) for _ in range(world_size)] + global_topk_w = [torch.empty_like(topk_w) for _ in range(world_size)] + global_out = [torch.empty_like(out) for _ in range(world_size)] + global_grad = [torch.empty_like(tokens.grad) for _ in range(world_size)] + dist.all_gather(global_tokens, tokens.detach()) + dist.all_gather(global_topk_idx, topk_idx) + dist.all_gather(global_topk_w, topk_w) + dist.all_gather(global_out, out.detach()) + dist.all_gather(global_grad, tokens.grad) + if rank == 0: + all_tokens = torch.cat(global_tokens).float().cpu().numpy() + all_idx = torch.cat(global_topk_idx).cpu().numpy() + all_w = torch.cat(global_topk_w).cpu().numpy() + all_out = torch.cat(global_out).float().cpu().numpy() + all_grad = torch.cat(global_grad).float().cpu().numpy() + ref_out, ref_grad = _reference_grad(all_tokens, all_idx, all_w, kernels_np) + np.testing.assert_allclose(all_out, ref_out, rtol=5e-2, atol=5e-2) + 
np.testing.assert_allclose(all_grad, ref_grad, rtol=5e-2, atol=5e-2) + print(f"[ep_moe] --check PASSED (ref_out.sum()={float(ref_out.sum()):.4f})") + + +if __name__ == "__main__": + main() + sys.exit(0) diff --git a/examples/pytorch/ep/run_test_ep.sh b/examples/pytorch/ep/run_test_ep.sh new file mode 100755 index 0000000000..13b41f4cb2 --- /dev/null +++ b/examples/pytorch/ep/run_test_ep.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -uo pipefail + +DETECTED_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +NUM_GPUS="${NUM_GPUS:-${DETECTED_GPUS}}" +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "EP requires >= 4 GPUs (found ${NUM_GPUS}); SKIPPING." + exit 0 +fi +if [ "${NUM_GPUS}" -gt 8 ]; then NUM_GPUS=8; fi + +: ${TE_PATH:=/opt/transformerengine} +: ${TEST_TIMEOUT_S:=120} + +SCRIPT="${TE_PATH}/examples/pytorch/ep/ep_moe.py" +export PYTHONPATH="${TE_PATH}${PYTHONPATH:+:${PYTHONPATH}}" + +# Stage JIT cubins on tmpfs for fast iteration. +: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"} +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "$NCCL_EP_JIT_CACHE_DIR" + +echo "*** Executing ep_moe.py across ${NUM_GPUS} GPUs (timeout=${TEST_TIMEOUT_S}s) ***" +timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + torchrun --standalone --nnodes=1 --nproc-per-node="${NUM_GPUS}" \ + "${SCRIPT}" --check 2>&1 | tee stdout_ep_moe.txt +RC=${PIPESTATUS[0]} + +RET=0 +if [ "${RC}" -ne 0 ]; then RET=1; fi +if grep -qE "(^|]:)FAILED|(^|]:)Traceback" stdout_ep_moe.txt; then RET=1; fi +rm -f stdout_ep_moe.txt +exit $RET diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py new file mode 100644 index 0000000000..7f74a454aa --- /dev/null +++ b/tests/pytorch/distributed/run_ep.py @@ -0,0 +1,338 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Multi-process PyTorch EP tests, launched via torchrun (one process per GPU).""" + +import os +import sys +import unittest + +import numpy as np +import torch +import torch.distributed as dist + +from transformer_engine.pytorch.ep import ( + EpHandle, + EpBuffer, + ep_bootstrap, + ep_finalize, + ep_prepare, + ep_dispatch, + ep_combine, + _ep_combine_raw, + _ep_dispatch_raw, +) + +# Must come after the transformer_engine import so libtransformer_engine.so is loaded. +import transformer_engine_torch as tex # noqa: F401 + + +NUM_LOCAL_EXPERTS = 2 +HIDDEN_DIM = 32 +TOP_K = 2 +TOKENS_PER_RANK = 4 + + +def _device_sm() -> int: + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def _build_ep_group(): + """EP group spanning all ranks of the default PG.""" + world_pg = dist.distributed_c10d._get_default_group() + ranks = list(range(world_pg.size())) + return dist.new_group(ranks=ranks, backend="nccl") + + +def _make_identity_inputs(rank, ep_size, device="cuda"): + """Per-rank identity routing + uniform weights so combine matches tokens.""" + T = TOKENS_PER_RANK + E = ep_size * NUM_LOCAL_EXPERTS + topk_idx = np.empty((T, TOP_K), dtype=np.int64) + base = rank * T + for t in range(T): + for k in range(TOP_K): + topk_idx[t, k] = ((base + t) * TOP_K + k) % E + tokens_np = np.linspace( + 0.1 + rank * 0.01, 0.9 + rank * 0.01, T * HIDDEN_DIM, dtype=np.float32 + ).reshape(T, HIDDEN_DIM) + topk_weights = np.full((T, TOP_K), 1.0 / TOP_K, dtype=np.float32) + return ( + torch.from_numpy(topk_idx).to(device), + torch.from_numpy(tokens_np).to(device=device, dtype=torch.bfloat16), + torch.from_numpy(topk_weights).to(device), + ) + + +class _Cfg: + rank: int + world_size: int + ep_size: int + num_experts: int + recv_capacity_per_rank: int + device: torch.device + + +def _make_cfg() -> _Cfg: + cfg = _Cfg() + cfg.rank = dist.get_rank() + cfg.world_size = dist.get_world_size() + cfg.ep_size = cfg.world_size + cfg.num_experts = NUM_LOCAL_EXPERTS * 
class TestEP(unittest.TestCase):
    """Multi-rank tests for the EP prepare / dispatch / combine ops.

    ``setUpClass`` bootstraps the EP backend once per process with
    ``zero_copy=True``. Each test builds its own ``EpHandle`` and, for the
    autograd entry points, its own ``EpBuffer``.
    """

    cfg: _Cfg
    ep_group: dist.ProcessGroup

    @classmethod
    def setUpClass(cls):
        """Skip below SM90, then build the config and group and bootstrap EP."""
        sm = _device_sm()
        if sm < 90:
            raise unittest.SkipTest(f"NCCL EP requires SM>=90 (got SM{sm})")
        cls.cfg = _make_cfg()
        cls.ep_group = _build_ep_group()
        ep_bootstrap(
            cls.ep_group,
            num_experts=cls.cfg.num_experts,
            max_tokens_per_rank=TOKENS_PER_RANK,
            recv_capacity_per_rank=cls.cfg.recv_capacity_per_rank,
            hidden_dim=HIDDEN_DIM,
            zero_copy=True,
        )

    def _make_handle(self, alignment=0, top_k=TOP_K):
        """Build an ``EpHandle`` sized from the module constants and ``self.cfg``."""
        return EpHandle(
            top_k=top_k,
            max_tokens_per_rank=TOKENS_PER_RANK,
            recv_capacity_per_rank=self.cfg.recv_capacity_per_rank,
            hidden_dim=HIDDEN_DIM,
            num_local_experts=NUM_LOCAL_EXPERTS,
            alignment=alignment,
        )

    def _make_buffers(self, dtype=torch.bfloat16):
        """Allocate raw recv buffers + token_counts for the primitive tests.

        Returns ``(recv_tokens, recv_topk_weights, token_counts)`` as
        uninitialized tensors on ``self.cfg.device``.
        """
        rc = self.cfg.recv_capacity_per_rank
        return (
            torch.empty(rc, HIDDEN_DIM, dtype=dtype, device=self.cfg.device),
            torch.empty(rc, dtype=torch.float32, device=self.cfg.device),
            torch.empty(NUM_LOCAL_EXPERTS, dtype=torch.int32, device=self.cfg.device),
        )

    def _make_ep_buffer(self, handle):
        """Allocate an ``EpBuffer`` for ``handle`` on the suite's EP group.

        The suite bootstraps with ``zero_copy=True``; in that mode ``EpBuffer``
        raises ``ValueError`` unless it is given the EP group on which to
        allocate its symm-mem slots, so the group is always passed.
        """
        return EpBuffer(handle, ep_group=self.ep_group)

    @staticmethod
    def _weighted(recv_tokens, recv_w):
        """fp32 per-slot weighting + cast back; matches the upstream combine input."""
        # Slots whose weight is exactly zero are masked out explicitly.
        mask = (recv_w != 0).to(torch.float32).unsqueeze(-1)
        return (recv_tokens.float() * recv_w.unsqueeze(-1).float() * mask).to(recv_tokens.dtype)

    def _moe_step(self, handle, buffer, topk_idx, tokens, w):
        """Run autograd dispatch, weight the received tokens, then combine."""
        recv_t, recv_w_out, _tc = ep_dispatch(handle, buffer, tokens, topk_idx, w)
        eo = self._weighted(recv_t, recv_w_out)
        return ep_combine(handle, buffer, eo)

    # Prepare

    def test_primitive_prepare(self):
        """token_counts has one entry per local expert and sums, across ranks,
        to ``world_size * TOKENS_PER_RANK * TOP_K``."""
        handle = self._make_handle()
        topk_idx, _toks, _w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size)
        token_counts = ep_prepare(handle, topk_idx)
        torch.cuda.synchronize()
        self.assertEqual(token_counts.shape, (NUM_LOCAL_EXPERTS,))
        local = int(token_counts.sum().item())
        total = torch.tensor([local], dtype=torch.int64, device=self.cfg.device)
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=self.ep_group)
        self.assertEqual(int(total.item()), self.cfg.world_size * TOKENS_PER_RANK * TOP_K)

    # Identity round-trip via raw primitives

    def test_primitive_dispatch_combine_identity(self):
        """Raw prepare -> dispatch -> weighted combine reproduces the input
        tokens within a 5e-2 tolerance."""
        handle = self._make_handle()
        topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size)
        recv_tokens, recv_w, _ = self._make_buffers()
        ep_prepare(handle, topk_idx)
        _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w)
        result = torch.empty_like(tokens)
        _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result)
        torch.cuda.synchronize()
        torch.testing.assert_close(result.float(), tokens.float(), atol=5e-2, rtol=5e-2)

    # Autograd

    def test_dispatch_fwd_bwd(self):
        """0.5*||recv_tokens||^2 ; grad_tokens equals TOP_K * tokens."""
        handle = self._make_handle()
        buffer = self._make_ep_buffer(handle)
        topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size)
        tokens_p = tokens.detach().clone().requires_grad_(True)
        recv_t, _recv_w, _tc = ep_dispatch(handle, buffer, tokens_p, topk_idx, w)
        loss = 0.5 * (recv_t.float() ** 2).sum()
        loss.backward()
        torch.cuda.synchronize()
        torch.testing.assert_close(
            tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2
        )

    def test_combine_fwd_bwd(self):
        """Full dispatch + combine fwd+bwd; identity inputs round-trip."""
        handle = self._make_handle()
        buffer = self._make_ep_buffer(handle)
        topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size)
        tokens_p = tokens.detach().clone().requires_grad_(True)
        out = self._moe_step(handle, buffer, topk_idx, tokens_p, w)
        loss = 0.5 * (out.float() ** 2).sum()
        loss.backward()
        torch.cuda.synchronize()
        torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2)
        torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2)

    # Multi-iter stability

    def test_dispatch_fwd_bwd_multiple_iterations(self):
        """5 fwd+bwd iters on the same EpHandle + EpBuffer must be bit-stable."""
        handle = self._make_handle()
        buffer = self._make_ep_buffer(handle)
        topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size)

        def one_step():
            tokens_p = tokens.detach().clone().requires_grad_(True)
            out = self._moe_step(handle, buffer, topk_idx, tokens_p, w)
            loss = 0.5 * (out.float() ** 2).sum()
            loss.backward()
            # Clone so later iterations cannot alias the reference results.
            return out.detach().clone(), tokens_p.grad.detach().clone()

        out_ref, grad_ref = one_step()
        torch.cuda.synchronize()
        for _ in range(4):
            out_i, grad_i = one_step()
            torch.cuda.synchronize()
            torch.testing.assert_close(out_i, out_ref, atol=0, rtol=0)
            torch.testing.assert_close(grad_i, grad_ref, atol=0, rtol=0)

    # CUDA graph

    def test_cuda_graph_capture(self):
        """Capture raw dispatch+combine into a CUDA graph; replay must be bit-stable."""
        handle = self._make_handle()
        topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size)
        recv_tokens, recv_w, _ = self._make_buffers()
        result = torch.empty_like(tokens)

        def step():
            ep_prepare(handle, topk_idx)
            _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w)
            _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result)

        # Eager warm-up iterations before capture.
        for _ in range(3):
            step()
        torch.cuda.synchronize()

        # Routing is fixed per layer; prepare runs once before capture.
        ep_prepare(handle, topk_idx)
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            with torch.cuda.graph(graph):
                _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w)
                _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result)
        torch.cuda.current_stream().wait_stream(s)
        torch.cuda.synchronize()

        # ``result`` still holds the last eager output; every replay must match it.
        ref = result.clone()
        for _ in range(5):
            graph.replay()
            torch.cuda.synchronize()
            torch.testing.assert_close(result.float(), ref.float(), atol=0, rtol=0)

    # PP-1F1B handle isolation

    def test_pp_1f1b_two_handles(self):
        """PP-1F1B interleave (F0 F1 B0 F2 B1 B2) over 3 per-microbatch handles."""
        T, H = TOKENS_PER_RANK, HIDDEN_DIM
        idx, _toks, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size)
        scales = (0.13, 0.41, 0.77)
        handles, buffers, tokens, tokens_p = [], [], [], []
        for s in scales:
            h = self._make_handle()
            handles.append(h)
            buffers.append(self._make_ep_buffer(h))
            # Distinct constant per microbatch and per rank.
            t = torch.full(
                (T, H), s + self.cfg.rank * 0.01, dtype=torch.bfloat16, device=self.cfg.device
            )
            tokens.append(t)
            tokens_p.append(t.detach().clone().requires_grad_(True))

        recv = [None, None, None]

        def fwd(k):
            recv[k], _, _ = ep_dispatch(handles[k], buffers[k], tokens_p[k], idx, w)

        def bwd(k):
            (0.5 * (recv[k].float() ** 2).sum()).backward()
            recv[k] = None

        fwd(0)
        fwd(1)
        bwd(0)
        fwd(2)
        bwd(1)
        bwd(2)
        torch.cuda.synchronize()
        for k in range(3):
            torch.testing.assert_close(
                tokens_p[k].grad.float(),
                tokens[k].float() * float(TOP_K),
                atol=5e-2,
                rtol=5e-2,
            )

    # Input validation

    def test_topk_int32_raises_clear_error(self):
        """int32 ``topk_idx`` raises a RuntimeError that names ``topk_idx`` and
        suggests ``.long()``."""
        handle = self._make_handle()
        topk_idx_int32 = torch.zeros(
            TOKENS_PER_RANK, TOP_K, dtype=torch.int32, device=self.cfg.device
        )
        with self.assertRaises(RuntimeError) as cm:
            ep_prepare(handle, topk_idx_int32)
        msg = str(cm.exception)
        self.assertIn("topk_idx", msg)
        self.assertIn(".long()", msg)
@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="EP requires >= 4 GPUs")
def test_multi_process_ep():
    """Run the EP launcher script in a subprocess and require exit code 0.

    The per-run budget comes from ``NVTE_TEST_EP_TIMEOUT_S`` (default 180 s)
    and is forwarded to the launcher as ``TEST_TIMEOUT_S``. The subprocess
    itself gets 30 s on top of that, so a hang on any rank surfaces fast
    rather than burning CI time.
    """
    budget_s = int(os.getenv("NVTE_TEST_EP_TIMEOUT_S", "180"))

    # Child environment: inherit everything, then pin the launcher knobs.
    child_env = dict(os.environ)
    child_env["KEEP_EP_LOGS"] = "1"
    child_env["TEST_TIMEOUT_S"] = str(budget_s)

    completed = subprocess.run(
        ["bash", str(LAUNCHER)],
        env=child_env,
        timeout=budget_s + 30,
        check=False,
    )
    rc = completed.returncode
    assert rc == 0, f"EP test suite failed (rc={rc})"
+std::atomic g_zero_copy_enabled{false}; + +// Per-step ops always pass kNoWindow in this release; the symm-mem IO path is +// planned for a near-future release. +constexpr NVTECommWindow kNoWindow = {nullptr, 0}; + +// Resolve ``t`` to an NCCL symm-mem window for the zero-copy one-sided path. +// Returns ``kNoWindow`` when symm-mem support isn't compiled in, zero-copy is +// disabled, no group is set, or ``t`` isn't symm-mem-backed. Currently unused +// at per-step call sites (they hardcode kNoWindow); kept so flipping +// ``g_zero_copy_enabled`` is the only change needed once the backend's +// symm-mem IO path is exposed. +[[maybe_unused]] NVTECommWindow maybe_make_window(const at::Tensor& t) { +#ifdef NCCL_HAS_SYMMEM_SUPPORT + if (!g_zero_copy_enabled.load(std::memory_order_relaxed)) return kNoWindow; + if (g_ep_group_name.empty()) return kNoWindow; + auto sm = c10d::symmetric_memory::rendezvous(t, g_ep_group_name); + if (sm == nullptr) return kNoWindow; + auto* nccl_sm = dynamic_cast(sm.get()); + NVTE_CHECK(nccl_sm != nullptr, + "Symm-mem backend mismatch: expected NCCLSymmetricMemory. Set the backend to " + "\"NCCL\" before allocating EP payload buffers."); + return NVTECommWindow{static_cast(nccl_sm->get_window()), + static_cast(nccl_sm->get_offset())}; +#else + (void)t; + return kNoWindow; +#endif +} + +// The backend only accepts int64 topk_idx. The PyTorch wrapper enforces this +// at the boundary so the per-step ops don't need an upcast workspace. +void check_topk_idx_int64(at::Tensor topk_idx) { + NVTE_CHECK(topk_idx.is_contiguous(), "topk_idx must be contiguous"); + NVTE_CHECK(topk_idx.scalar_type() == at::kLong, + "topk_idx must be int64; got dtype=", c10::toString(topk_idx.scalar_type()), + ". 
Cast with topk_idx.long() before calling."); +} + +using Shape = std::vector; + +} // namespace + +bool ep_get_zero_copy() { return g_zero_copy_enabled.load(std::memory_order_relaxed); } + +// ── Bootstrap ──────────────────────────────────────────────────────────────── +// Borrows torch's NCCL host comm (from ``ProcessGroupNCCL._comm_ptr()``). +// ``group_name`` is captured for the symm-mem window resolver. + +void ep_initialize(uintptr_t comm_ptr, const std::string& group_name, int64_t num_experts, + int64_t max_tokens_per_rank, int64_t max_recv_tokens_per_rank, + int64_t hidden_dim, int64_t max_num_sms, + pybind11::object max_token_dtype, bool zero_copy) { + NVTE_CHECK(!group_name.empty(), "group_name must be non-empty (used for symm-mem lookup)"); + NVTE_CHECK(comm_ptr != 0, "comm_ptr must be non-null (torch NCCL host comm pointer)"); + NVTE_CHECK(!g_ep_initialized, "ep_initialize called twice without ep_finalize"); + + auto ep_comm = reinterpret_cast(comm_ptr); + int ep_size = 0; + NVTE_CHECK(ncclCommCount(ep_comm, &ep_size) == ncclSuccess, "ncclCommCount failed"); + auto torch_dtype = max_token_dtype.cast(); + NVTEEpGroupConfig cfg{ + /*ep_size=*/ep_size, + /*num_experts=*/static_cast(num_experts), + /*max_tokens_per_rank=*/static_cast(max_tokens_per_rank), + /*max_recv_tokens_per_rank=*/static_cast(max_recv_tokens_per_rank), + /*hidden_dim=*/static_cast(hidden_dim), + /*max_num_sms=*/static_cast(max_num_sms), + /*max_token_dtype=*/static_cast(GetTransformerEngineDType(torch_dtype)), + /*zero_copy=*/zero_copy ? 1 : 0, + }; + nvte_ep_initialize(static_cast(ep_comm), cfg); + g_zero_copy_enabled.store(zero_copy, std::memory_order_relaxed); + g_ep_initialized = true; + g_ep_group_name = group_name; +} + +void ep_finalize() { + if (!g_ep_initialized) return; + // The borrowed comm is owned by torch's symm-mem layer; don't destroy it. 
+ nvte_ep_shutdown(); + g_ep_initialized = false; + g_ep_group_name.clear(); + g_zero_copy_enabled.store(false, std::memory_order_relaxed); +} + +namespace { + +NVTEEpLayerConfig make_layer_cfg(int64_t top_k, int64_t dispatch_output_per_expert_alignment) { + return NVTEEpLayerConfig{ + /*top_k=*/static_cast(top_k), + /*dispatch_output_per_expert_alignment=*/ + static_cast(dispatch_output_per_expert_alignment), + }; +} + +} // namespace + +int64_t ep_handle_mem_size(int64_t top_k, int64_t dispatch_output_per_expert_alignment) { + return static_cast( + nvte_ep_handle_mem_size(make_layer_cfg(top_k, dispatch_output_per_expert_alignment))); +} + +// ── Per-step ops ───────────────────────────────────────────────────────────── + +void ep_prepare(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor token_counts, int64_t top_k, + int64_t dispatch_output_per_expert_alignment) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(topk_idx.dim() >= 2, "topk_idx must be at least 2D [..., top_k]"); + check_topk_idx_int64(topk_idx); + const size_t T_flat = topk_idx.numel() / topk_idx.size(-1); + const size_t topk_n = static_cast(topk_idx.size(-1)); + + auto topk_idx_te = + makeTransformerEngineTensor(topk_idx.data_ptr(), Shape{T_flat, topk_n}, DType::kInt64); + auto token_counts_te = makeTransformerEngineTensor( + token_counts.data_ptr(), Shape{static_cast(token_counts.numel())}, DType::kInt32); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + + nvte_ep_prepare(handle_mem_te.data(), topk_idx_te.data(), token_counts_te.data(), + make_layer_cfg(top_k, dispatch_output_per_expert_alignment), stream); +} + +void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, + at::Tensor topk_weights, at::Tensor recv_tokens, at::Tensor recv_topk_weights) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(tokens.dim() >= 2, "tokens must be at 
least 2D [..., H]"); + NVTE_CHECK(topk_idx.dim() >= 2, "topk_idx must be at least 2D [..., top_k]"); + NVTE_CHECK(topk_weights.dim() >= 2, "topk_weights must be at least 2D [..., top_k]"); + NVTE_CHECK(recv_tokens.dim() >= 2, "recv_tokens must be at least 2D [..., recv_pr, H]"); + check_topk_idx_int64(topk_idx); + + const size_t H = static_cast(tokens.size(-1)); + const size_t T_flat = tokens.numel() / H; + const size_t topk_n = static_cast(topk_idx.size(-1)); + const size_t recv_pr = recv_tokens.numel() / H; + + NVTE_CHECK(static_cast(topk_weights.size(-1)) == topk_n, + "topk_weights last dim must equal topk_idx last dim"); + NVTE_CHECK(static_cast(recv_topk_weights.numel()) == recv_pr, + "recv_topk_weights total size must equal recv_tokens recv_pr"); + NVTE_CHECK(recv_tokens.scalar_type() == tokens.scalar_type(), "recv_tokens dtype (", + c10::toString(recv_tokens.scalar_type()), ") must match tokens dtype (", + c10::toString(tokens.scalar_type()), ")"); + + auto tok_dtype = GetTransformerEngineDType(tokens.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto topk_idx_te = + makeTransformerEngineTensor(topk_idx.data_ptr(), Shape{T_flat, topk_n}, DType::kInt64); + auto tokens_te = makeTransformerEngineTensor(tokens.data_ptr(), Shape{T_flat, H}, tok_dtype); + auto topk_w_te = + makeTransformerEngineTensor(topk_weights.data_ptr(), Shape{T_flat, topk_n}, DType::kFloat32); + auto recv_tokens_te = + makeTransformerEngineTensor(recv_tokens.data_ptr(), Shape{recv_pr, H}, tok_dtype); + auto recv_topk_w_te = + makeTransformerEngineTensor(recv_topk_weights.data_ptr(), Shape{recv_pr}, DType::kFloat32); + + // top_k / alignment are carried by the cached layer_cfg seeded at ep_prepare; + // per-step ops look them up by handle_mem pointer in the backend. 
+ nvte_ep_dispatch(handle_mem_te.data(), topk_idx_te.data(), tokens_te.data(), kNoWindow, + topk_w_te.data(), kNoWindow, recv_tokens_te.data(), kNoWindow, + recv_topk_w_te.data(), kNoWindow, stream); +} + +void ep_combine(at::Tensor handle_mem, at::Tensor expert_out, at::Tensor result) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(expert_out.dim() >= 2, "expert_out must be at least 2D [..., recv_pr, H]"); + NVTE_CHECK(result.dim() >= 2, "result must be at least 2D [..., H]"); + + const size_t H = static_cast(expert_out.size(-1)); + const size_t recv_pr = expert_out.numel() / H; + const size_t T_flat = result.numel() / H; + NVTE_CHECK(static_cast(result.size(-1)) == H, + "result hidden dim must equal expert_out hidden dim"); + NVTE_CHECK(result.scalar_type() == expert_out.scalar_type(), "result dtype (", + c10::toString(result.scalar_type()), ") must match expert_out dtype (", + c10::toString(expert_out.scalar_type()), ")"); + + auto eo_dtype = GetTransformerEngineDType(expert_out.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto expert_out_te = + makeTransformerEngineTensor(expert_out.data_ptr(), Shape{recv_pr, H}, eo_dtype); + auto result_te = makeTransformerEngineTensor(result.data_ptr(), Shape{T_flat, H}, eo_dtype); + + nvte_ep_combine(handle_mem_te.data(), expert_out_te.data(), kNoWindow, result_te.data(), stream); +} + +void ep_dispatch_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor g_recv_topk_weights, + at::Tensor grad_tokens, at::Tensor grad_topk_weights) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(grad.dim() >= 2, "grad must be at least 2D [..., recv_pr, H]"); + NVTE_CHECK(grad_tokens.dim() >= 2, "grad_tokens must be at least 2D [..., H]"); + NVTE_CHECK(grad_topk_weights.dim() >= 2, "grad_topk_weights must be at least 2D [..., top_k]"); + + const size_t H = static_cast(grad.size(-1)); + 
const size_t recv_pr = grad.numel() / H; + const size_t T_flat = grad_tokens.numel() / H; + const size_t topk_n = static_cast(grad_topk_weights.size(-1)); + NVTE_CHECK(static_cast(g_recv_topk_weights.numel()) == recv_pr, + "g_recv_topk_weights total size must equal grad recv_pr"); + NVTE_CHECK(static_cast(grad_tokens.size(-1)) == H, + "grad_tokens hidden dim must equal grad H"); + NVTE_CHECK(static_cast(grad_topk_weights.numel()) == T_flat * topk_n, + "grad_topk_weights numel (", grad_topk_weights.numel(), + ") must equal T_flat * top_k (", T_flat * topk_n, ")"); + NVTE_CHECK(grad_tokens.scalar_type() == grad.scalar_type(), "grad_tokens dtype (", + c10::toString(grad_tokens.scalar_type()), ") must match grad dtype (", + c10::toString(grad.scalar_type()), ")"); + + auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto grad_te = makeTransformerEngineTensor(grad.data_ptr(), Shape{recv_pr, H}, g_dtype); + auto g_recv_w_te = + makeTransformerEngineTensor(g_recv_topk_weights.data_ptr(), Shape{recv_pr}, DType::kFloat32); + auto grad_tokens_te = + makeTransformerEngineTensor(grad_tokens.data_ptr(), Shape{T_flat, H}, g_dtype); + auto grad_topk_w_te = makeTransformerEngineTensor(grad_topk_weights.data_ptr(), + Shape{T_flat, topk_n}, DType::kFloat32); + + nvte_ep_dispatch_bwd(handle_mem_te.data(), grad_te.data(), kNoWindow, g_recv_w_te.data(), + kNoWindow, grad_tokens_te.data(), grad_topk_w_te.data(), stream); +} + +void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expert_out) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + NVTE_CHECK(grad.dim() >= 2, "grad must be at least 2D [..., H]"); + NVTE_CHECK(grad_expert_out.dim() >= 2, "grad_expert_out must be at least 2D [..., recv_pr, H]"); + + const size_t H = static_cast(grad.size(-1)); + const size_t T_flat = grad.numel() / H; + const size_t 
recv_pr = grad_expert_out.numel() / H; + NVTE_CHECK(static_cast(grad_expert_out.size(-1)) == H, + "grad_expert_out hidden dim must match grad H"); + NVTE_CHECK(grad_expert_out.scalar_type() == grad.scalar_type(), "grad_expert_out dtype (", + c10::toString(grad_expert_out.scalar_type()), ") must match grad dtype (", + c10::toString(grad.scalar_type()), ")"); + + auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); + auto handle_mem_te = makeTransformerEngineTensor( + handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto grad_te = makeTransformerEngineTensor(grad.data_ptr(), Shape{T_flat, H}, g_dtype); + auto grad_eo_te = + makeTransformerEngineTensor(grad_expert_out.data_ptr(), Shape{recv_pr, H}, g_dtype); + + nvte_ep_combine_bwd(handle_mem_te.data(), grad_te.data(), kNoWindow, grad_eo_te.data(), + kNoWindow, stream); +} + +void register_ep_bindings(pybind11::module_& m) { + namespace py = pybind11; + m.def("ep_initialize", &ep_initialize, + "Initialize the EP backend; borrows torch's NCCL comm pointed to by ``comm_ptr``.", + py::arg("comm_ptr"), py::arg("group_name"), py::arg("num_experts"), + py::arg("max_tokens_per_rank"), py::arg("max_recv_tokens_per_rank"), py::arg("hidden_dim"), + py::arg("max_num_sms") = 0, py::arg("max_token_dtype"), py::arg("zero_copy") = false, + py::call_guard()); + m.def("ep_finalize", &ep_finalize, "Tear down the EP backend. 
Idempotent.", + py::call_guard()); + m.def("ep_get_zero_copy", &ep_get_zero_copy, "Return the current EP zero-copy toggle state."); + m.def("ep_handle_mem_size", &ep_handle_mem_size, + "Return the handle_mem byte size for the given layer config.", py::arg("top_k"), + py::arg("dispatch_output_per_expert_alignment") = 0); + m.def("ep_prepare", &ep_prepare, "EP prepare", py::call_guard()); + m.def("ep_dispatch", &ep_dispatch, "EP dispatch", py::call_guard()); + m.def("ep_combine", &ep_combine, "EP combine", py::call_guard()); + m.def("ep_dispatch_bwd", &ep_dispatch_bwd, "EP dispatch backward", + py::call_guard()); + m.def("ep_combine_bwd", &ep_combine_bwd, "EP combine backward", + py::call_guard()); +} + +} // namespace transformer_engine::pytorch + +#endif // NVTE_WITH_NCCL_EP diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d6089b1e01..e55b6defa0 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -292,6 +292,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "DSquaredReLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); +#ifdef NVTE_WITH_NCCL_EP + transformer_engine::pytorch::register_ep_bindings(m); +#endif // NVTE_WITH_NCCL_EP + // Permutation functions m.def("moe_permute_fwd", transformer_engine::pytorch::moe_permute_fwd, "MOE permute FWD", py::call_guard()); diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py new file mode 100644 index 0000000000..bc4b3bb5d1 --- /dev/null +++ b/transformer_engine/pytorch/ep.py @@ -0,0 +1,734 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
def symm_mem_alloc(
    shape,
    dtype: torch.dtype,
    ep_group: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Return a symm-mem tensor of ``shape``/``dtype`` rendezvoused on ``ep_group``.

    Collective on ``ep_group``. ``device`` defaults to the current CUDA
    device. The symm-mem backend is switched to ``"NCCL"`` when it is not
    already set to it.

    Raises:
        RuntimeError: ``torch.distributed._symmetric_memory`` cannot be imported.
    """
    target = device if device is not None else torch.device("cuda", torch.cuda.current_device())

    try:
        from torch.distributed import _symmetric_memory as symm
    except ImportError as err:
        raise RuntimeError(
            "torch.distributed._symmetric_memory is unavailable; symm_mem_alloc "
            "requires PyTorch built with NCCL symm-mem support."
        ) from err

    backend_is_nccl = symm.get_backend(target) == "NCCL"
    if not backend_is_nccl:
        symm.set_backend("NCCL")

    buf = symm.empty(*shape, dtype=dtype, device=target)
    symm.rendezvous(buf, group=ep_group.group_name)
    return buf
+_MIN_NCCL_VERSION = (2, 30, 4) + + +def _check_nccl_runtime_version() -> None: + """Raise with a clear message if the loaded libnccl is too old for NCCL EP.""" + import ctypes + + try: + lib = ctypes.CDLL("libnccl.so.2", mode=ctypes.RTLD_GLOBAL) + v = ctypes.c_int(0) + if lib.ncclGetVersion(ctypes.byref(v)) != 0: + import warnings + + warnings.warn("ncclGetVersion failed; skipping NCCL EP version check.") + return + except OSError: # libnccl not findable; let the C++ side error + return + n = v.value + # NCCL packs as (major*10000 + minor*100 + patch) up to ~2.x; newer + # builds use the same scheme. Decode defensively. + major, minor, patch = n // 10000, (n // 100) % 100, n % 100 + if (major, minor, patch) < _MIN_NCCL_VERSION: + min_str = ".".join(str(x) for x in _MIN_NCCL_VERSION) + raise RuntimeError( + f"NCCL EP requires NCCL >= {min_str}, found {major}.{minor}.{patch} at runtime. " + "Set LD_LIBRARY_PATH to a newer libnccl.so before launching." + ) + + +_BOOTSTRAPPED = False +_ATEXIT_REGISTERED = False + + +def _atexit_finalize() -> None: + """Best-effort teardown at interpreter shutdown; swallows errors.""" + global _BOOTSTRAPPED + if _BOOTSTRAPPED: + try: + tex.ep_finalize() + except Exception: + import traceback + + traceback.print_exc() + finally: + _BOOTSTRAPPED = False + + +def ep_bootstrap( + ep_group: dist.ProcessGroup, + num_experts: int, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + max_num_sms: int = 0, + zero_copy: bool = False, + max_token_dtype: torch.dtype = torch.bfloat16, +) -> None: + """Initialize EP by borrowing ep_group's NCCL comm. Call once per process. + + max_token_dtype sets the widest token dtype this EP group will dispatch; + it sizes NCCL EP staging buffers. + + ``zero_copy`` opts the EP group into the symm-mem zero-copy IO path; pass + ``True`` only when payload tensors are allocated via ``symm_mem_alloc``. + Defaults to ``False``. 
def ep_finalize() -> None:
    """Tear down the EP backend explicitly; a no-op when EP is not bootstrapped.

    An atexit handler covers normal interpreter shutdown. Call this only
    before ``dist.destroy_process_group()``, since the borrowed NCCL comm is
    invalid once the PG is destroyed.

    Errors from the C++ teardown propagate to the caller; the bootstrapped
    flag is cleared either way. ``_atexit_finalize`` is the best-effort
    variant used at interpreter shutdown.
    """
    global _BOOTSTRAPPED
    if _BOOTSTRAPPED:
        try:
            tex.ep_finalize()
        finally:
            # Clear the flag even when the C++ teardown raised.
            _BOOTSTRAPPED = False
class EpHandle:
    """Routing context for one EP layer. Construct one per concurrently-live
    microbatch (e.g. one per in-flight PP-1F1B step).

    The constructor stores the layer sizes as plain ints and allocates
    ``handle_mem``, a ``uint8`` tensor whose byte size is given by
    ``tex.ep_handle_mem_size(top_k, alignment)``.
    """

    __slots__ = (
        "handle_mem",
        "top_k",
        "alignment",
        "max_tokens_per_rank",
        "recv_capacity_per_rank",
        "hidden_dim",
        "num_local_experts",
        "payload_dtype",
        "device",
    )

    def __init__(
        self,
        top_k: int,
        max_tokens_per_rank: int,
        recv_capacity_per_rank: int,
        hidden_dim: int,
        num_local_experts: int,
        alignment: int = 0,
        device: Optional[torch.device] = None,
        payload_dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        """Validate ``alignment``, record the layer config, allocate ``handle_mem``.

        Args:
            alignment: must be 0, 1, or a power of two.
            device: defaults to the current CUDA device.

        Raises:
            ValueError: ``alignment`` is negative or not a power of two.
        """
        if device is None:
            device = torch.device("cuda", torch.cuda.current_device())
        alignment = int(alignment)
        # n & (n - 1) is zero exactly for 0, 1 and powers of two among the
        # non-negative ints; negatives are rejected explicitly.
        if alignment < 0 or (alignment & (alignment - 1)) != 0:
            raise ValueError(
                f"alignment must be 0, 1, or a power of two (got {alignment})."
            )
        self.top_k = int(top_k)
        self.alignment = alignment
        self.max_tokens_per_rank = int(max_tokens_per_rank)
        self.recv_capacity_per_rank = int(recv_capacity_per_rank)
        self.hidden_dim = int(hidden_dim)
        self.num_local_experts = int(num_local_experts)
        self.payload_dtype = payload_dtype
        self.device = device
        size_bytes = tex.ep_handle_mem_size(self.top_k, self.alignment)
        self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device)
+ + Use one EpBuffer per concurrently-in-flight call on the layer (one per + PP-1F1B microbatch); sharing between an outstanding fwd and a later call + overwrites tensors the earlier bwd still reads. Call record_stream from + streams other than the allocation stream. + """ + + __slots__ = ( + "recv_tokens", + "combine_in", + "recv_topk_weights", + "token_counts", + "grad_tokens", + "grad_topk_weights", + ) + + def __init__( + self, + handle: EpHandle, + ep_group: Optional[dist.ProcessGroup] = None, + *, + device: Optional[torch.device] = None, + ) -> None: + """Allocate the persistent EP slots. + + Cross-rank slots are symm-mem-backed when ``ep_bootstrap`` was called + with ``zero_copy=True`` (requires ``ep_group``); otherwise plain HBM. + """ + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + recv_shape = (handle.recv_capacity_per_rank, handle.hidden_dim) + send_shape = (handle.max_tokens_per_rank, handle.hidden_dim) + zero_copy = bool(tex.ep_get_zero_copy()) + if zero_copy: + if ep_group is None: + raise ValueError("EpBuffer requires ep_group when ep_bootstrap(zero_copy=True).") + self.recv_tokens = symm_mem_alloc( + recv_shape, handle.payload_dtype, ep_group, device=device + ) + self.combine_in = symm_mem_alloc( + recv_shape, handle.payload_dtype, ep_group, device=device + ) + self.recv_topk_weights = symm_mem_alloc( + (handle.recv_capacity_per_rank,), torch.float32, ep_group, device=device + ) + self.grad_tokens = symm_mem_alloc( + send_shape, handle.payload_dtype, ep_group, device=device + ) + else: + self.recv_tokens = torch.empty(recv_shape, dtype=handle.payload_dtype, device=device) + self.combine_in = torch.empty(recv_shape, dtype=handle.payload_dtype, device=device) + self.recv_topk_weights = torch.empty( + handle.recv_capacity_per_rank, dtype=torch.float32, device=device + ) + self.grad_tokens = torch.empty(send_shape, dtype=handle.payload_dtype, device=device) + # Per-rank scratch; never cross-rank, plain HBM regardless 
of mode. + self.token_counts = torch.empty(handle.num_local_experts, dtype=torch.int32, device=device) + self.grad_topk_weights = torch.empty( + (handle.max_tokens_per_rank, handle.top_k), dtype=torch.float32, device=device + ) + + @classmethod + def from_external( + cls, + handle: EpHandle, + *, + recv_tokens: torch.Tensor, + combine_in: torch.Tensor, + recv_topk_weights: Optional[torch.Tensor] = None, + grad_tokens: Optional[torch.Tensor] = None, + token_counts: Optional[torch.Tensor] = None, + grad_topk_weights: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + ) -> "EpBuffer": + """Construct from caller-allocated buffers. + + Useful for sharing a pre-allocated pool across layers/microbatches, and + for plugging in symm-mem-backed tensors once the zero-copy IO fast path + ships in a near-future release. Caller-supplied slots are validated + against the expected shape and dtype; ``None`` slots get a fresh HBM + allocation. + """ + if device is None: + device = torch.device("cuda", torch.cuda.current_device()) + recv_shape = (handle.recv_capacity_per_rank, handle.hidden_dim) + send_shape = (handle.max_tokens_per_rank, handle.hidden_dim) + topk_shape = (handle.max_tokens_per_rank, handle.top_k) + recv_w_shape = (handle.recv_capacity_per_rank,) + counts_shape = (handle.num_local_experts,) + + def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torch.Tensor: + if tuple(t.shape) != shape: + raise ValueError(f"{name} shape {tuple(t.shape)} != expected {shape}") + if t.dtype != dtype: + raise ValueError(f"{name} dtype {t.dtype} != expected {dtype}") + return t + + inst = cls.__new__(cls) + inst.recv_tokens = _check(recv_tokens, "recv_tokens", recv_shape, handle.payload_dtype) + inst.combine_in = _check(combine_in, "combine_in", recv_shape, handle.payload_dtype) + inst.recv_topk_weights = ( + _check(recv_topk_weights, "recv_topk_weights", recv_w_shape, torch.float32) + if recv_topk_weights is not None + else 
torch.empty(recv_w_shape, dtype=torch.float32, device=device) + ) + inst.grad_tokens = ( + _check(grad_tokens, "grad_tokens", send_shape, handle.payload_dtype) + if grad_tokens is not None + else torch.empty(send_shape, dtype=handle.payload_dtype, device=device) + ) + inst.token_counts = ( + _check(token_counts, "token_counts", counts_shape, torch.int32) + if token_counts is not None + else torch.empty(counts_shape, dtype=torch.int32, device=device) + ) + inst.grad_topk_weights = ( + _check(grad_topk_weights, "grad_topk_weights", topk_shape, torch.float32) + if grad_topk_weights is not None + else torch.empty(topk_shape, dtype=torch.float32, device=device) + ) + return inst + + def record_stream(self, stream: torch.cuda.Stream) -> None: + """Record stream as a user of all owned tensors so the caching allocator + defers reclaim until stream has caught up.""" + for t in ( + self.recv_tokens, + self.combine_in, + self.recv_topk_weights, + self.grad_tokens, + self.token_counts, + self.grad_topk_weights, + ): + t.record_stream(stream) + + +# torch.library custom ops (so they don't graph-break under torch.compile) + +_LIB = "transformer_engine_ep" + + +@torch.library.custom_op( + f"{_LIB}::prepare", + mutates_args=("handle_mem", "token_counts"), + device_types="cuda", +) +def _prepare_op( + handle_mem: torch.Tensor, + top_k: int, + topk_idx: torch.Tensor, + token_counts: torch.Tensor, + alignment: int, +) -> None: + tex.ep_prepare(handle_mem, topk_idx, token_counts, top_k, alignment) + + +@_prepare_op.register_fake +def _(*args, **kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::dispatch", + mutates_args=("recv_tokens", "recv_topk_weights"), + device_types="cuda", +) +def _dispatch_op( + handle_mem: torch.Tensor, + topk_idx: torch.Tensor, + tokens: torch.Tensor, + topk_weights: torch.Tensor, + recv_tokens: torch.Tensor, + recv_topk_weights: torch.Tensor, +) -> None: + tex.ep_dispatch(handle_mem, topk_idx, tokens, topk_weights, recv_tokens, recv_topk_weights) 
+ + +@_dispatch_op.register_fake +def _(*args, **kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::combine", + mutates_args=("result",), + device_types="cuda", +) +def _combine_op( + handle_mem: torch.Tensor, + expert_out: torch.Tensor, + result: torch.Tensor, +) -> None: + tex.ep_combine(handle_mem, expert_out, result) + + +@_combine_op.register_fake +def _(*args, **kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::dispatch_bwd", + mutates_args=("grad_tokens", "grad_topk_weights"), + device_types="cuda", +) +def _dispatch_bwd_op( + handle_mem: torch.Tensor, + grad: torch.Tensor, + g_recv_topk_weights: torch.Tensor, + grad_tokens: torch.Tensor, + grad_topk_weights: torch.Tensor, +) -> None: + tex.ep_dispatch_bwd(handle_mem, grad, g_recv_topk_weights, grad_tokens, grad_topk_weights) + + +@_dispatch_bwd_op.register_fake +def _(*args, **kw): + return None + + +@torch.library.custom_op( + f"{_LIB}::combine_bwd", + mutates_args=("grad_expert_out",), + device_types="cuda", +) +def _combine_bwd_op( + handle_mem: torch.Tensor, + grad: torch.Tensor, + grad_expert_out: torch.Tensor, +) -> None: + tex.ep_combine_bwd(handle_mem, grad, grad_expert_out) + + +@_combine_bwd_op.register_fake +def _(*args, **kw): + return None + + +# Non-autograd primitives + + +def ep_prepare(handle: EpHandle, topk_idx: torch.Tensor) -> torch.Tensor: + """AllGather the routing map; fills handle.handle_mem and returns token_counts + (int32, shape [num_local_experts]). topk_idx must be int64. + """ + token_counts = torch.empty(handle.num_local_experts, dtype=torch.int32, device=handle.device) + torch.ops.transformer_engine_ep.prepare( + handle.handle_mem, handle.top_k, topk_idx, token_counts, handle.alignment + ) + return token_counts + + +def _ep_dispatch_raw( + handle: EpHandle, + topk_idx: torch.Tensor, + tokens: torch.Tensor, + topk_weights: torch.Tensor, + recv_tokens: torch.Tensor, + recv_topk_weights: torch.Tensor, +) -> None: + """Raw dispatch; no autograd, no prepare. 
Caller must run ep_prepare first.""" + tex.ep_dispatch( + handle.handle_mem, topk_idx, tokens, topk_weights, recv_tokens, recv_topk_weights + ) + + +def _ep_combine_raw(handle: EpHandle, expert_out: torch.Tensor, result: torch.Tensor) -> None: + """Raw combine; no autograd. Caller pre-weights expert_out.""" + tex.ep_combine(handle.handle_mem, expert_out, result) + + +# autograd.Function wrappers + + +class _EpDispatch(torch.autograd.Function): + """Autograd-aware prepare + dispatch. Fwd/bwd share handle_mem and the + EpBuffer slots; do not re-run ep_prepare between them and do not share + EpBuffer with another in-flight call (see EpBuffer). + """ + + @staticmethod + def forward( # type: ignore[override] + ctx, + handle_mem: torch.Tensor, + top_k: int, + alignment: int, + recv_tokens: torch.Tensor, + recv_topk_weights: torch.Tensor, + token_counts: torch.Tensor, + grad_tokens_buf: torch.Tensor, + grad_topk_weights_buf: torch.Tensor, + topk_idx: torch.Tensor, + tokens: torch.Tensor, + topk_weights: torch.Tensor, + ): + torch.ops.transformer_engine_ep.prepare( + handle_mem, top_k, topk_idx, token_counts, alignment + ) + torch.ops.transformer_engine_ep.dispatch( + handle_mem, + topk_idx, + tokens, + topk_weights, + recv_tokens, + recv_topk_weights, + ) + ctx.handle_mem = handle_mem + ctx.grad_tokens_buf = grad_tokens_buf + ctx.grad_topk_weights_buf = grad_topk_weights_buf + ctx.tokens_shape = tokens.shape + ctx.tokens_dtype = tokens.dtype + ctx.topk_weights_shape = topk_weights.shape + ctx.topk_weights_dtype = topk_weights.dtype + ctx.tokens_T_flat = tokens.numel() // tokens.shape[-1] + ctx.topk_T_flat = topk_weights.numel() // topk_weights.shape[-1] + ctx.recv_capacity = recv_tokens.shape[0] + ctx.hidden_dim = tokens.shape[-1] + ctx.mark_non_differentiable(token_counts) + # Detach so the long-lived buffers aren't tracked as differentiable outputs; + # autograd re-attaches grad_fn pointing back at this Function. 
+ return recv_tokens.detach(), recv_topk_weights.detach(), token_counts + + @staticmethod + def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] + device = ctx.handle_mem.device + if g_recv_tokens is None: + g_recv_tokens = torch.zeros( + ctx.recv_capacity, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device + ) + if g_recv_topk_weights is None: + g_recv_topk_weights = torch.zeros(ctx.recv_capacity, dtype=torch.float32, device=device) + if not g_recv_tokens.is_contiguous(): + g_recv_tokens = g_recv_tokens.contiguous() + if not g_recv_topk_weights.is_contiguous(): + g_recv_topk_weights = g_recv_topk_weights.contiguous() + # Narrow the persistent slots to this call's flattened leading dim. + grad_tokens = ctx.grad_tokens_buf.narrow(0, 0, ctx.tokens_T_flat) + grad_topk_weights = ctx.grad_topk_weights_buf.narrow(0, 0, ctx.topk_T_flat) + torch.ops.transformer_engine_ep.dispatch_bwd( + ctx.handle_mem, + g_recv_tokens, + g_recv_topk_weights, + grad_tokens, + grad_topk_weights, + ) + # Reshape back to the original input shape so autograd's grad slot matches. + grad_tokens_out = grad_tokens.view(ctx.tokens_shape) + grad_topk_weights_out = grad_topk_weights.view(ctx.topk_weights_shape) + return ( + None, # handle_mem + None, # top_k + None, # alignment + None, # recv_tokens + None, # recv_topk_weights + None, # token_counts + None, # grad_tokens_buf + None, # grad_topk_weights_buf + None, # topk_idx + grad_tokens_out, + grad_topk_weights_out, + ) + + +class _EpCombine(torch.autograd.Function): + """Autograd-aware combine. combine_in is reused as grad_combine_in in bwd; + fwd/bwd share handle_mem so don't re-run ep_prepare between them. Caller + must pre-apply the topk weighting to expert_out. 
+ """ + + @staticmethod + def forward( # type: ignore[override] + ctx, + handle_mem: torch.Tensor, + combine_in: torch.Tensor, + num_local_tokens: int, + hidden_dim: int, + expert_out: torch.Tensor, + ): + device = expert_out.device + # Stage expert_out into the persistent combine_in slot (symm-mem-backed + # in the zero-copy path); its storage is reused as grad_combine_in in bwd. + combine_in.copy_(expert_out) + result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) + torch.ops.transformer_engine_ep.combine(handle_mem, combine_in, result) + ctx.handle_mem = handle_mem + ctx.combine_in = combine_in # reused as grad_combine_in in bwd + return result + + @staticmethod + def backward(ctx, g_result): # type: ignore[override] + grad_combine_in = ctx.combine_in + if not g_result.is_contiguous(): + g_result = g_result.contiguous() + torch.ops.transformer_engine_ep.combine_bwd(ctx.handle_mem, g_result, grad_combine_in) + return ( + None, # handle_mem + None, # combine_in + None, # num_local_tokens + None, # hidden_dim + grad_combine_in, + ) + + +# Public high-level wrappers + + +# FP8 dispatch is not yet supported by the common backend. +_FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2) + + +def _reject_fp8(*tensors: torch.Tensor) -> None: + for t in tensors: + if t.dtype in _FP8_DTYPES: + raise NotImplementedError( + f"FP8 dispatch/combine not supported (got dtype={t.dtype}); " + "quantize outside the EP boundary." + ) + + +def ep_dispatch( + handle: EpHandle, + buffer: EpBuffer, + tokens: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, +): + """Run prepare + dispatch with autograd. topk_idx must be int64. + + Returns (recv_tokens, recv_topk_weights, token_counts); views into buffer's + persistent slots; consume them before the next ep_dispatch on the same + buffer or they get overwritten. token_counts is non-differentiable. 
+ """ + _reject_fp8(tokens, buffer.recv_tokens) + return _EpDispatch.apply( + handle.handle_mem, + handle.top_k, + handle.alignment, + buffer.recv_tokens, + buffer.recv_topk_weights, + buffer.token_counts, + buffer.grad_tokens, + buffer.grad_topk_weights, + topk_idx, + tokens, + topk_weights, + ) + + +def ep_combine( + handle: EpHandle, + buffer: EpBuffer, + expert_out: torch.Tensor, + *, + num_local_tokens: Optional[int] = None, +): + """Combine expert outputs back to the source rank, with autograd. The + caller must pre-apply the topk weighting to expert_out. + + Result shape is (num_local_tokens, handle.hidden_dim); defaults to + handle.max_tokens_per_rank rows. + """ + _reject_fp8(expert_out, buffer.combine_in) + if num_local_tokens is None: + num_local_tokens = handle.max_tokens_per_rank + return _EpCombine.apply( + handle.handle_mem, + buffer.combine_in, + num_local_tokens, + handle.hidden_dim, + expert_out, + ) From dcc3172f725aedb36bab919645d8efb2a11cd00a Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 16:28:36 -0700 Subject: [PATCH 02/18] EP PyTorch: wire maybe_make_window into per-step ops for zero_copy Signed-off-by: Phuong Nguyen --- .../pytorch/csrc/extensions/ep.cpp | 41 ++++++++++++------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index cdb0b4239a..018b5addc0 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -59,15 +59,18 @@ constexpr NVTECommWindow kNoWindow = {nullptr, 0}; // Resolve ``t`` to an NCCL symm-mem window for the zero-copy one-sided path. // Returns ``kNoWindow`` when symm-mem support isn't compiled in, zero-copy is -// disabled, no group is set, or ``t`` isn't symm-mem-backed. 
Currently unused -// at per-step call sites (they hardcode kNoWindow); kept so flipping -// ``g_zero_copy_enabled`` is the only change needed once the backend's -// symm-mem IO path is exposed. -[[maybe_unused]] NVTECommWindow maybe_make_window(const at::Tensor& t) { +// disabled, no group is set, or ``t`` isn't symm-mem-backed; callers pass the +// resulting window unconditionally to the backend. +NVTECommWindow maybe_make_window(const at::Tensor& t) { #ifdef NCCL_HAS_SYMMEM_SUPPORT if (!g_zero_copy_enabled.load(std::memory_order_relaxed)) return kNoWindow; if (g_ep_group_name.empty()) return kNoWindow; - auto sm = c10d::symmetric_memory::rendezvous(t, g_ep_group_name); + c10::intrusive_ptr sm; + try { + sm = c10d::symmetric_memory::rendezvous(t, g_ep_group_name); + } catch (const std::exception&) { + return kNoWindow; // Tensor not symm-mem-backed; fall back to staged copy. + } if (sm == nullptr) return kNoWindow; auto* nccl_sm = dynamic_cast(sm.get()); NVTE_CHECK(nccl_sm != nullptr, @@ -212,9 +215,13 @@ void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, // top_k / alignment are carried by the cached layer_cfg seeded at ep_prepare; // per-step ops look them up by handle_mem pointer in the backend. 
- nvte_ep_dispatch(handle_mem_te.data(), topk_idx_te.data(), tokens_te.data(), kNoWindow, - topk_w_te.data(), kNoWindow, recv_tokens_te.data(), kNoWindow, - recv_topk_w_te.data(), kNoWindow, stream); + NVTECommWindow tokens_win = maybe_make_window(tokens); + NVTECommWindow topk_w_win = maybe_make_window(topk_weights); + NVTECommWindow recv_tokens_win = maybe_make_window(recv_tokens); + NVTECommWindow recv_topk_w_win = maybe_make_window(recv_topk_weights); + nvte_ep_dispatch(handle_mem_te.data(), topk_idx_te.data(), tokens_te.data(), tokens_win, + topk_w_te.data(), topk_w_win, recv_tokens_te.data(), recv_tokens_win, + recv_topk_w_te.data(), recv_topk_w_win, stream); } void ep_combine(at::Tensor handle_mem, at::Tensor expert_out, at::Tensor result) { @@ -238,7 +245,9 @@ void ep_combine(at::Tensor handle_mem, at::Tensor expert_out, at::Tensor result) makeTransformerEngineTensor(expert_out.data_ptr(), Shape{recv_pr, H}, eo_dtype); auto result_te = makeTransformerEngineTensor(result.data_ptr(), Shape{T_flat, H}, eo_dtype); - nvte_ep_combine(handle_mem_te.data(), expert_out_te.data(), kNoWindow, result_te.data(), stream); + NVTECommWindow expert_out_win = maybe_make_window(expert_out); + nvte_ep_combine(handle_mem_te.data(), expert_out_te.data(), expert_out_win, result_te.data(), + stream); } void ep_dispatch_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor g_recv_topk_weights, @@ -274,8 +283,10 @@ void ep_dispatch_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor g_recv_t auto grad_topk_w_te = makeTransformerEngineTensor(grad_topk_weights.data_ptr(), Shape{T_flat, topk_n}, DType::kFloat32); - nvte_ep_dispatch_bwd(handle_mem_te.data(), grad_te.data(), kNoWindow, g_recv_w_te.data(), - kNoWindow, grad_tokens_te.data(), grad_topk_w_te.data(), stream); + NVTECommWindow grad_win = maybe_make_window(grad); + NVTECommWindow g_recv_w_win = maybe_make_window(g_recv_topk_weights); + nvte_ep_dispatch_bwd(handle_mem_te.data(), grad_te.data(), grad_win, g_recv_w_te.data(), 
+ g_recv_w_win, grad_tokens_te.data(), grad_topk_w_te.data(), stream); } void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expert_out) { @@ -299,8 +310,10 @@ void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expe auto grad_eo_te = makeTransformerEngineTensor(grad_expert_out.data_ptr(), Shape{recv_pr, H}, g_dtype); - nvte_ep_combine_bwd(handle_mem_te.data(), grad_te.data(), kNoWindow, grad_eo_te.data(), - kNoWindow, stream); + NVTECommWindow grad_win = maybe_make_window(grad); + NVTECommWindow grad_eo_win = maybe_make_window(grad_expert_out); + nvte_ep_combine_bwd(handle_mem_te.data(), grad_te.data(), grad_win, grad_eo_te.data(), + grad_eo_win, stream); } void register_ep_bindings(pybind11::module_& m) { From ee0df627ed727f618bffd84e9dc2c39cde6d43d5 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 16:28:36 -0700 Subject: [PATCH 03/18] EP PyTorch: merge EpHandle into EpBuffer; ep_dispatch/ep_combine take a single buffer Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/bench/ep_bench.py | 27 ++-- examples/pytorch/ep/ep_moe.py | 17 +-- tests/pytorch/distributed/run_ep.py | 86 ++++++----- transformer_engine/pytorch/ep.py | 202 +++++++++++++------------- 4 files changed, 159 insertions(+), 173 deletions(-) diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py index 86217b7f91..973703b710 100644 --- a/examples/pytorch/ep/bench/ep_bench.py +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -30,7 +30,6 @@ from transformer_engine.pytorch.ep import ( EpBuffer, - EpHandle, ep_bootstrap, ep_combine, ep_dispatch, @@ -177,14 +176,14 @@ def main(): topk_idx, tokens_hbm, topk_w_hbm = _make_inputs(rank, world_size, T, H, K, E, device) - handle = EpHandle( + buffer = EpBuffer( top_k=K, max_tokens_per_rank=T, recv_capacity_per_rank=recv_pr, hidden_dim=H, num_local_experts=num_local_experts, + ep_group=ep_group, ) - buffer = EpBuffer(handle) tokens = tokens_hbm topk_w = topk_w_hbm 
@@ -192,11 +191,11 @@ def main(): recv_w = torch.empty(recv_pr, dtype=torch.float32, device=device) # -- Prepare once outside the timed loops ------------------------------ - ep_prepare(handle, topk_idx) + ep_prepare(buffer, topk_idx) torch.cuda.synchronize() # Pre-dispatch a steady recv_tokens / recv_w so combine stages have valid input. - _ep_dispatch_raw(handle, topk_idx, tokens, topk_w, recv_tokens, recv_w) + _ep_dispatch_raw(buffer, topk_idx, tokens, topk_w, recv_tokens, recv_w) torch.cuda.synchronize() # fp-equivalent stand-in for an MLP output. expert_out = recv_tokens.clone() @@ -210,20 +209,20 @@ def main(): eo_p = recv_tokens.detach().clone().requires_grad_(True) # Stand-in callables; the cuda-graph branch below swaps in graphed versions. - fwd_bwd_dispatch_fn = lambda x: ep_dispatch(handle, buffer, x, topk_idx, topk_w)[ # noqa: E731 + fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w)[ # noqa: E731 0 ] - fwd_bwd_combine_fn = lambda eo: ep_combine(handle, buffer, eo) # noqa: E731 + fwd_bwd_combine_fn = lambda eo: ep_combine(buffer, eo) # noqa: E731 def _dispatch_raw(): - _ep_dispatch_raw(handle, topk_idx, tokens, topk_w, recv_tokens, recv_w) + _ep_dispatch_raw(buffer, topk_idx, tokens, topk_w, recv_tokens, recv_w) def _combine_raw(): out_buf = torch.empty(T, H, dtype=torch.bfloat16, device=device) - _ep_combine_raw(handle, expert_out, out_buf) + _ep_combine_raw(buffer, expert_out, out_buf) def _ep_dispatch_fwd(): - ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w) + ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) def _ep_dispatch_fwd_bwd(): tokens_p.grad = None @@ -231,7 +230,7 @@ def _ep_dispatch_fwd_bwd(): (0.5 * (r * r).sum(dtype=torch.float32)).backward() def _ep_combine_fwd(): - ep_combine(handle, buffer, recv_tokens) + ep_combine(buffer, recv_tokens) def _ep_combine_fwd_bwd(): eo_p.grad = None @@ -265,11 +264,11 @@ def _ep_combine_fwd_bwd(): # Graph fwd+bwd of the autograd-wrapped ops via 
make_graphed_callables. class _DispatchMod(torch.nn.Module): def forward(self, x): - return ep_dispatch(handle, buffer, x, topk_idx, topk_w)[0] + return ep_dispatch(buffer, x, topk_idx, topk_w)[0] class _CombineMod(torch.nn.Module): def forward(self, eo): - return ep_combine(handle, buffer, eo) + return ep_combine(buffer, eo) disp_mod = _DispatchMod().cuda() comb_mod = _CombineMod().cuda() @@ -383,7 +382,7 @@ def forward(self, eo): fwd_bwd_combine_fn = None captured_runners.clear() del g_disp, g_comb, disp_mod, comb_mod - del tokens_p, eo_p, buffer, handle, recv_tokens, recv_w, tokens, topk_w, expert_out + del tokens_p, eo_p, buffer, recv_tokens, recv_w, tokens, topk_w, expert_out gc.collect() torch.cuda.synchronize() # Release NCCL EP's borrowed comm before torch destroys it. diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index 934d88d8c7..70bf678f3c 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -15,7 +15,6 @@ import torch.distributed as dist from transformer_engine.pytorch.ep import ( - EpHandle, EpBuffer, ep_scope, ep_dispatch, @@ -147,20 +146,20 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, kernels_np[rank * num_local_experts : (rank + 1) * num_local_experts] ).to(device=device, dtype=torch.bfloat16) - handle = EpHandle( + buffer = EpBuffer( top_k=args.top_k, max_tokens_per_rank=T, recv_capacity_per_rank=recv_pr, hidden_dim=args.hidden, num_local_experts=num_local_experts, + ep_group=ep_group, ) - buffer = EpBuffer(handle) - recv_t, recv_w_out, _tc = ep_dispatch(handle, buffer, tokens, topk_idx, topk_w) + recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, topk_w) expert_out = _batched_expert_linear(recv_t, kernels_local, num_local_experts) # Apply per-slot topk weighting before combine. 
expert_out = expert_out * recv_w_out.unsqueeze(-1).to(expert_out.dtype) - out = ep_combine(handle, buffer, expert_out) + out = ep_combine(buffer, expert_out) loss = 0.5 * (out.float() ** 2).sum() loss.backward() @@ -179,18 +178,18 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, torch.cuda.synchronize() dist.barrier() for _ in range(args.benchmark_warmup): - rt, rw, _tc = ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w) + rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) eo = _batched_expert_linear(rt, kernels_local, num_local_experts) eo = eo * rw.unsqueeze(-1).to(eo.dtype) - ep_combine(handle, buffer, eo) + ep_combine(buffer, eo) torch.cuda.synchronize() dist.barrier() t0 = time.perf_counter() for _ in range(args.benchmark_iters): - rt, rw, _tc = ep_dispatch(handle, buffer, tokens.detach(), topk_idx, topk_w) + rt, rw, _tc = ep_dispatch(buffer, tokens.detach(), topk_idx, topk_w) eo = _batched_expert_linear(rt, kernels_local, num_local_experts) eo = eo * rw.unsqueeze(-1).to(eo.dtype) - ep_combine(handle, buffer, eo) + ep_combine(buffer, eo) torch.cuda.synchronize() dt_ms = (time.perf_counter() - t0) * 1000.0 / args.benchmark_iters if rank == 0: diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index 7f74a454aa..3b343466b9 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -12,17 +12,20 @@ import torch.distributed as dist from transformer_engine.pytorch.ep import ( - EpHandle, EpBuffer, ep_bootstrap, ep_finalize, ep_prepare, ep_dispatch, ep_combine, + symm_mem_alloc, _ep_combine_raw, _ep_dispatch_raw, ) + +ZERO_COPY = False + # Must come after the transformer_engine import so libtransformer_engine.so is loaded. 
import transformer_engine_torch as tex # noqa: F401 @@ -104,21 +107,22 @@ def setUpClass(cls): max_tokens_per_rank=TOKENS_PER_RANK, recv_capacity_per_rank=cls.cfg.recv_capacity_per_rank, hidden_dim=HIDDEN_DIM, - zero_copy=True, + zero_copy=ZERO_COPY, ) - def _make_handle(self, alignment=0, top_k=TOP_K): - return EpHandle( + def _make_buffer(self, alignment=0, top_k=TOP_K): + return EpBuffer( top_k=top_k, max_tokens_per_rank=TOKENS_PER_RANK, recv_capacity_per_rank=self.cfg.recv_capacity_per_rank, hidden_dim=HIDDEN_DIM, num_local_experts=NUM_LOCAL_EXPERTS, alignment=alignment, + ep_group=self.ep_group, ) - def _make_buffers(self, dtype=torch.bfloat16): - """Allocate raw recv buffers + token_counts for the primitive tests.""" + def _make_raw_recv(self, dtype=torch.bfloat16): + """Raw recv tensors + token_counts for the primitive tests.""" rc = self.cfg.recv_capacity_per_rank return ( torch.empty(rc, HIDDEN_DIM, dtype=dtype, device=self.cfg.device), @@ -126,26 +130,23 @@ def _make_buffers(self, dtype=torch.bfloat16): torch.empty(NUM_LOCAL_EXPERTS, dtype=torch.int32, device=self.cfg.device), ) - def _make_ep_buffer(self, handle): - return EpBuffer(handle) - @staticmethod def _weighted(recv_tokens, recv_w): """fp32 per-slot weighting + cast back; matches the upstream combine input.""" mask = (recv_w != 0).to(torch.float32).unsqueeze(-1) return (recv_tokens.float() * recv_w.unsqueeze(-1).float() * mask).to(recv_tokens.dtype) - def _moe_step(self, handle, buffer, topk_idx, tokens, w): - recv_t, recv_w_out, _tc = ep_dispatch(handle, buffer, tokens, topk_idx, w) + def _moe_step(self, buffer, topk_idx, tokens, w): + recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, w) eo = self._weighted(recv_t, recv_w_out) - return ep_combine(handle, buffer, eo) + return ep_combine(buffer, eo) # Prepare def test_primitive_prepare(self): - handle = self._make_handle() + buf = self._make_buffer() topk_idx, _toks, _w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) - 
token_counts = ep_prepare(handle, topk_idx) + token_counts = ep_prepare(buf, topk_idx) torch.cuda.synchronize() self.assertEqual(token_counts.shape, (NUM_LOCAL_EXPERTS,)) local = int(token_counts.sum().item()) @@ -156,13 +157,13 @@ def test_primitive_prepare(self): # Identity round-trip via raw primitives def test_primitive_dispatch_combine_identity(self): - handle = self._make_handle() + buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) - recv_tokens, recv_w, _ = self._make_buffers() - ep_prepare(handle, topk_idx) - _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) + recv_tokens, recv_w, _ = self._make_raw_recv() + ep_prepare(buf, topk_idx) + _ep_dispatch_raw(buf, topk_idx, tokens, w, recv_tokens, recv_w) result = torch.empty_like(tokens) - _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + _ep_combine_raw(buf, self._weighted(recv_tokens, recv_w), result) torch.cuda.synchronize() torch.testing.assert_close(result.float(), tokens.float(), atol=5e-2, rtol=5e-2) @@ -170,11 +171,10 @@ def test_primitive_dispatch_combine_identity(self): def test_dispatch_fwd_bwd(self): """0.5*||recv_tokens||^2 ; grad_tokens equals TOP_K * tokens.""" - handle = self._make_handle() - buffer = self._make_ep_buffer(handle) + buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) - recv_t, _recv_w, _tc = ep_dispatch(handle, buffer, tokens_p, topk_idx, w) + recv_t, _recv_w, _tc = ep_dispatch(buf, tokens_p, topk_idx, w) loss = 0.5 * (recv_t.float() ** 2).sum() loss.backward() torch.cuda.synchronize() @@ -184,11 +184,10 @@ def test_dispatch_fwd_bwd(self): def test_combine_fwd_bwd(self): """Full dispatch + combine fwd+bwd; identity inputs round-trip.""" - handle = self._make_handle() - buffer = self._make_ep_buffer(handle) + buf = self._make_buffer() topk_idx, tokens, w = 
_make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) - out = self._moe_step(handle, buffer, topk_idx, tokens_p, w) + out = self._moe_step(buf, topk_idx, tokens_p, w) loss = 0.5 * (out.float() ** 2).sum() loss.backward() torch.cuda.synchronize() @@ -198,14 +197,13 @@ def test_combine_fwd_bwd(self): # Multi-iter stability def test_dispatch_fwd_bwd_multiple_iterations(self): - """5 fwd+bwd iters on the same EpHandle + EpBuffer must be bit-stable.""" - handle = self._make_handle() - buffer = self._make_ep_buffer(handle) + """5 fwd+bwd iters on the same EpBuffer must be bit-stable.""" + buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) def one_step(): tokens_p = tokens.detach().clone().requires_grad_(True) - out = self._moe_step(handle, buffer, topk_idx, tokens_p, w) + out = self._moe_step(buf, topk_idx, tokens_p, w) loss = 0.5 * (out.float() ** 2).sum() loss.backward() return out.detach().clone(), tokens_p.grad.detach().clone() @@ -222,22 +220,22 @@ def one_step(): def test_cuda_graph_capture(self): """Capture raw dispatch+combine into a CUDA graph; replay must be bit-stable.""" - handle = self._make_handle() + buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) - recv_tokens, recv_w, _ = self._make_buffers() + recv_tokens, recv_w, _ = self._make_raw_recv() result = torch.empty_like(tokens) def step(): - ep_prepare(handle, topk_idx) - _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) - _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + ep_prepare(buf, topk_idx) + _ep_dispatch_raw(buf, topk_idx, tokens, w, recv_tokens, recv_w) + _ep_combine_raw(buf, self._weighted(recv_tokens, recv_w), result) for _ in range(3): step() torch.cuda.synchronize() # Routing is fixed per layer; prepare runs once before capture. 
- ep_prepare(handle, topk_idx) + ep_prepare(buf, topk_idx) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() @@ -245,8 +243,8 @@ def step(): s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): with torch.cuda.graph(graph): - _ep_dispatch_raw(handle, topk_idx, tokens, w, recv_tokens, recv_w) - _ep_combine_raw(handle, self._weighted(recv_tokens, recv_w), result) + _ep_dispatch_raw(buf, topk_idx, tokens, w, recv_tokens, recv_w) + _ep_combine_raw(buf, self._weighted(recv_tokens, recv_w), result) torch.cuda.current_stream().wait_stream(s) torch.cuda.synchronize() @@ -259,15 +257,13 @@ def step(): # PP-1F1B handle isolation def test_pp_1f1b_two_handles(self): - """PP-1F1B interleave (F0 F1 B0 F2 B1 B2) over 3 per-microbatch handles.""" + """PP-1F1B interleave (F0 F1 B0 F2 B1 B2) over 3 per-microbatch buffers.""" T, H = TOKENS_PER_RANK, HIDDEN_DIM idx, _toks, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) scales = (0.13, 0.41, 0.77) - handles, buffers, tokens, tokens_p = [], [], [], [] + buffers, tokens, tokens_p = [], [], [] for s in scales: - h = self._make_handle() - handles.append(h) - buffers.append(self._make_ep_buffer(h)) + buffers.append(self._make_buffer()) t = torch.full( (T, H), s + self.cfg.rank * 0.01, dtype=torch.bfloat16, device=self.cfg.device ) @@ -277,7 +273,7 @@ def test_pp_1f1b_two_handles(self): recv = [None, None, None] def fwd(k): - recv[k], _, _ = ep_dispatch(handles[k], buffers[k], tokens_p[k], idx, w) + recv[k], _, _ = ep_dispatch(buffers[k], tokens_p[k], idx, w) def bwd(k): (0.5 * (recv[k].float() ** 2).sum()).backward() @@ -301,12 +297,12 @@ def bwd(k): # Input validation def test_topk_int32_raises_clear_error(self): - handle = self._make_handle() + buf = self._make_buffer() topk_idx_int32 = torch.zeros( TOKENS_PER_RANK, TOP_K, dtype=torch.int32, device=self.cfg.device ) with self.assertRaises(RuntimeError) as cm: - ep_prepare(handle, topk_idx_int32) + ep_prepare(buf, topk_idx_int32) msg = str(cm.exception) 
self.assertIn("topk_idx", msg) self.assertIn(".long()", msg) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index bc4b3bb5d1..9e66d882b4 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -16,7 +16,6 @@ __all__ = [ - "EpHandle", "EpBuffer", "ep_bootstrap", "ep_finalize", @@ -203,24 +202,43 @@ def ep_scope( ep_finalize() -# Handle +# Buffer -class EpHandle: - """Routing context for one EP layer. Construct one per concurrently-live - microbatch (e.g. one per in-flight PP-1F1B step). +class EpBuffer: + """Per-microbatch EP layer state: routing handle + persistent payload slots. + + Owns the per-call ``handle_mem`` routing scratch and the payload buffers + consumed by :func:`ep_dispatch` / :func:`ep_combine`. Allocate one + EpBuffer per concurrently-in-flight call on the layer (one per PP-1F1B + microbatch); sharing across overlapping calls clobbers tensors the + earlier bwd still reads. Call ``record_stream`` from streams other than + the allocation stream. + + Cross-rank payload slots are symm-mem-backed when ``ep_bootstrap`` was + called with ``zero_copy=True`` (requires ``ep_group``); otherwise plain + HBM. 
""" __slots__ = ( + # routing "handle_mem", "top_k", "alignment", + # layer config "max_tokens_per_rank", "recv_capacity_per_rank", "hidden_dim", "num_local_experts", "payload_dtype", "device", + # payload slots + "recv_tokens", + "combine_in", + "recv_topk_weights", + "token_counts", + "grad_tokens", + "grad_topk_weights", ) def __init__( @@ -231,16 +249,15 @@ def __init__( hidden_dim: int, num_local_experts: int, alignment: int = 0, - device: Optional[torch.device] = None, + ep_group: Optional[dist.ProcessGroup] = None, payload_dtype: torch.dtype = torch.bfloat16, + device: Optional[torch.device] = None, ) -> None: if device is None: device = torch.device("cuda", torch.cuda.current_device()) alignment = int(alignment) if alignment > 1 and (alignment & (alignment - 1)) != 0: - raise ValueError( - f"alignment must be 0, 1, or a power of two (got {alignment})." - ) + raise ValueError(f"alignment must be 0, 1, or a power of two (got {alignment}).") self.top_k = int(top_k) self.alignment = alignment self.max_tokens_per_rank = int(max_tokens_per_rank) @@ -249,83 +266,43 @@ def __init__( self.num_local_experts = int(num_local_experts) self.payload_dtype = payload_dtype self.device = device + size_bytes = tex.ep_handle_mem_size(self.top_k, self.alignment) self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) - -# Buffer - - -class EpBuffer: - """Persistent payload and scratch buffers for one EP layer. - - All slots are plain HBM in TE 2.17 (the symm-mem IO fast path is planned - for a near-future release). - - Use one EpBuffer per concurrently-in-flight call on the layer (one per - PP-1F1B microbatch); sharing between an outstanding fwd and a later call - overwrites tensors the earlier bwd still reads. Call record_stream from - streams other than the allocation stream. 
- """ - - __slots__ = ( - "recv_tokens", - "combine_in", - "recv_topk_weights", - "token_counts", - "grad_tokens", - "grad_topk_weights", - ) - - def __init__( - self, - handle: EpHandle, - ep_group: Optional[dist.ProcessGroup] = None, - *, - device: Optional[torch.device] = None, - ) -> None: - """Allocate the persistent EP slots. - - Cross-rank slots are symm-mem-backed when ``ep_bootstrap`` was called - with ``zero_copy=True`` (requires ``ep_group``); otherwise plain HBM. - """ - if device is None: - device = torch.device("cuda", torch.cuda.current_device()) - recv_shape = (handle.recv_capacity_per_rank, handle.hidden_dim) - send_shape = (handle.max_tokens_per_rank, handle.hidden_dim) + recv_shape = (self.recv_capacity_per_rank, self.hidden_dim) + send_shape = (self.max_tokens_per_rank, self.hidden_dim) zero_copy = bool(tex.ep_get_zero_copy()) if zero_copy: if ep_group is None: raise ValueError("EpBuffer requires ep_group when ep_bootstrap(zero_copy=True).") - self.recv_tokens = symm_mem_alloc( - recv_shape, handle.payload_dtype, ep_group, device=device - ) - self.combine_in = symm_mem_alloc( - recv_shape, handle.payload_dtype, ep_group, device=device - ) + self.recv_tokens = symm_mem_alloc(recv_shape, payload_dtype, ep_group, device=device) + self.combine_in = symm_mem_alloc(recv_shape, payload_dtype, ep_group, device=device) self.recv_topk_weights = symm_mem_alloc( - (handle.recv_capacity_per_rank,), torch.float32, ep_group, device=device - ) - self.grad_tokens = symm_mem_alloc( - send_shape, handle.payload_dtype, ep_group, device=device + (self.recv_capacity_per_rank,), torch.float32, ep_group, device=device ) + self.grad_tokens = symm_mem_alloc(send_shape, payload_dtype, ep_group, device=device) else: - self.recv_tokens = torch.empty(recv_shape, dtype=handle.payload_dtype, device=device) - self.combine_in = torch.empty(recv_shape, dtype=handle.payload_dtype, device=device) + self.recv_tokens = torch.empty(recv_shape, dtype=payload_dtype, device=device) + 
self.combine_in = torch.empty(recv_shape, dtype=payload_dtype, device=device) self.recv_topk_weights = torch.empty( - handle.recv_capacity_per_rank, dtype=torch.float32, device=device + self.recv_capacity_per_rank, dtype=torch.float32, device=device ) - self.grad_tokens = torch.empty(send_shape, dtype=handle.payload_dtype, device=device) + self.grad_tokens = torch.empty(send_shape, dtype=payload_dtype, device=device) # Per-rank scratch; never cross-rank, plain HBM regardless of mode. - self.token_counts = torch.empty(handle.num_local_experts, dtype=torch.int32, device=device) + self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) self.grad_topk_weights = torch.empty( - (handle.max_tokens_per_rank, handle.top_k), dtype=torch.float32, device=device + (self.max_tokens_per_rank, self.top_k), dtype=torch.float32, device=device ) @classmethod def from_external( cls, - handle: EpHandle, + top_k: int, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + num_local_experts: int, *, recv_tokens: torch.Tensor, combine_in: torch.Tensor, @@ -333,23 +310,27 @@ def from_external( grad_tokens: Optional[torch.Tensor] = None, token_counts: Optional[torch.Tensor] = None, grad_topk_weights: Optional[torch.Tensor] = None, + alignment: int = 0, + payload_dtype: torch.dtype = torch.bfloat16, device: Optional[torch.device] = None, ) -> "EpBuffer": - """Construct from caller-allocated buffers. + """Construct from caller-allocated payload buffers. - Useful for sharing a pre-allocated pool across layers/microbatches, and - for plugging in symm-mem-backed tensors once the zero-copy IO fast path - ships in a near-future release. Caller-supplied slots are validated - against the expected shape and dtype; ``None`` slots get a fresh HBM - allocation. + Useful for sharing a pre-allocated pool across layers/microbatches and + for plugging in symm-mem-backed tensors. 
Caller-supplied slots are + validated against the expected shape and dtype; ``None`` slots get a + fresh HBM allocation. ``handle_mem`` is always allocated fresh. """ if device is None: device = torch.device("cuda", torch.cuda.current_device()) - recv_shape = (handle.recv_capacity_per_rank, handle.hidden_dim) - send_shape = (handle.max_tokens_per_rank, handle.hidden_dim) - topk_shape = (handle.max_tokens_per_rank, handle.top_k) - recv_w_shape = (handle.recv_capacity_per_rank,) - counts_shape = (handle.num_local_experts,) + alignment = int(alignment) + if alignment > 1 and (alignment & (alignment - 1)) != 0: + raise ValueError(f"alignment must be 0, 1, or a power of two (got {alignment}).") + recv_shape = (recv_capacity_per_rank, hidden_dim) + send_shape = (max_tokens_per_rank, hidden_dim) + topk_shape = (max_tokens_per_rank, top_k) + recv_w_shape = (recv_capacity_per_rank,) + counts_shape = (num_local_experts,) def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torch.Tensor: if tuple(t.shape) != shape: @@ -359,17 +340,29 @@ def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torc return t inst = cls.__new__(cls) - inst.recv_tokens = _check(recv_tokens, "recv_tokens", recv_shape, handle.payload_dtype) - inst.combine_in = _check(combine_in, "combine_in", recv_shape, handle.payload_dtype) + inst.top_k = int(top_k) + inst.alignment = alignment + inst.max_tokens_per_rank = int(max_tokens_per_rank) + inst.recv_capacity_per_rank = int(recv_capacity_per_rank) + inst.hidden_dim = int(hidden_dim) + inst.num_local_experts = int(num_local_experts) + inst.payload_dtype = payload_dtype + inst.device = device + + size_bytes = tex.ep_handle_mem_size(inst.top_k, inst.alignment) + inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) + + inst.recv_tokens = _check(recv_tokens, "recv_tokens", recv_shape, payload_dtype) + inst.combine_in = _check(combine_in, "combine_in", recv_shape, payload_dtype) 
inst.recv_topk_weights = ( _check(recv_topk_weights, "recv_topk_weights", recv_w_shape, torch.float32) if recv_topk_weights is not None else torch.empty(recv_w_shape, dtype=torch.float32, device=device) ) inst.grad_tokens = ( - _check(grad_tokens, "grad_tokens", send_shape, handle.payload_dtype) + _check(grad_tokens, "grad_tokens", send_shape, payload_dtype) if grad_tokens is not None - else torch.empty(send_shape, dtype=handle.payload_dtype, device=device) + else torch.empty(send_shape, dtype=payload_dtype, device=device) ) inst.token_counts = ( _check(token_counts, "token_counts", counts_shape, torch.int32) @@ -387,6 +380,7 @@ def record_stream(self, stream: torch.cuda.Stream) -> None: """Record stream as a user of all owned tensors so the caching allocator defers reclaim until stream has caught up.""" for t in ( + self.handle_mem, self.recv_tokens, self.combine_in, self.recv_topk_weights, @@ -502,19 +496,19 @@ def _(*args, **kw): # Non-autograd primitives -def ep_prepare(handle: EpHandle, topk_idx: torch.Tensor) -> torch.Tensor: - """AllGather the routing map; fills handle.handle_mem and returns token_counts - (int32, shape [num_local_experts]). topk_idx must be int64. +def ep_prepare(buffer: "EpBuffer", topk_idx: torch.Tensor) -> torch.Tensor: + """AllGather the routing map; fills ``buffer.handle_mem`` and returns + ``buffer.token_counts`` (int32, shape [num_local_experts]). topk_idx must + be int64. 
""" - token_counts = torch.empty(handle.num_local_experts, dtype=torch.int32, device=handle.device) torch.ops.transformer_engine_ep.prepare( - handle.handle_mem, handle.top_k, topk_idx, token_counts, handle.alignment + buffer.handle_mem, buffer.top_k, topk_idx, buffer.token_counts, buffer.alignment ) - return token_counts + return buffer.token_counts def _ep_dispatch_raw( - handle: EpHandle, + buffer: "EpBuffer", topk_idx: torch.Tensor, tokens: torch.Tensor, topk_weights: torch.Tensor, @@ -523,13 +517,13 @@ def _ep_dispatch_raw( ) -> None: """Raw dispatch; no autograd, no prepare. Caller must run ep_prepare first.""" tex.ep_dispatch( - handle.handle_mem, topk_idx, tokens, topk_weights, recv_tokens, recv_topk_weights + buffer.handle_mem, topk_idx, tokens, topk_weights, recv_tokens, recv_topk_weights ) -def _ep_combine_raw(handle: EpHandle, expert_out: torch.Tensor, result: torch.Tensor) -> None: +def _ep_combine_raw(buffer: "EpBuffer", expert_out: torch.Tensor, result: torch.Tensor) -> None: """Raw combine; no autograd. Caller pre-weights expert_out.""" - tex.ep_combine(handle.handle_mem, expert_out, result) + tex.ep_combine(buffer.handle_mem, expert_out, result) # autograd.Function wrappers @@ -681,7 +675,6 @@ def _reject_fp8(*tensors: torch.Tensor) -> None: def ep_dispatch( - handle: EpHandle, buffer: EpBuffer, tokens: torch.Tensor, topk_idx: torch.Tensor, @@ -689,15 +682,15 @@ def ep_dispatch( ): """Run prepare + dispatch with autograd. topk_idx must be int64. - Returns (recv_tokens, recv_topk_weights, token_counts); views into buffer's - persistent slots; consume them before the next ep_dispatch on the same - buffer or they get overwritten. token_counts is non-differentiable. + Returns (recv_tokens, recv_topk_weights, token_counts); views into the + buffer's persistent slots — consume them before the next ep_dispatch on + the same buffer or they get overwritten. token_counts is non-differentiable. 
""" _reject_fp8(tokens, buffer.recv_tokens) return _EpDispatch.apply( - handle.handle_mem, - handle.top_k, - handle.alignment, + buffer.handle_mem, + buffer.top_k, + buffer.alignment, buffer.recv_tokens, buffer.recv_topk_weights, buffer.token_counts, @@ -710,7 +703,6 @@ def ep_dispatch( def ep_combine( - handle: EpHandle, buffer: EpBuffer, expert_out: torch.Tensor, *, @@ -719,16 +711,16 @@ def ep_combine( """Combine expert outputs back to the source rank, with autograd. The caller must pre-apply the topk weighting to expert_out. - Result shape is (num_local_tokens, handle.hidden_dim); defaults to - handle.max_tokens_per_rank rows. + Result shape is (num_local_tokens, buffer.hidden_dim); defaults to + buffer.max_tokens_per_rank rows. """ _reject_fp8(expert_out, buffer.combine_in) if num_local_tokens is None: - num_local_tokens = handle.max_tokens_per_rank + num_local_tokens = buffer.max_tokens_per_rank return _EpCombine.apply( - handle.handle_mem, + buffer.handle_mem, buffer.combine_in, num_local_tokens, - handle.hidden_dim, + buffer.hidden_dim, expert_out, ) From e230f4b926825d2c9415a5673336bc59b44f2872 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 17:39:37 -0700 Subject: [PATCH 04/18] EP PyTorch example: drop stale ep_group kwarg from EpBuffer call Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/ep_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index 70bf678f3c..ca7469951d 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -152,7 +152,6 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, recv_capacity_per_rank=recv_pr, hidden_dim=args.hidden, num_local_experts=num_local_experts, - ep_group=ep_group, ) recv_t, recv_w_out, _tc = ep_dispatch(buffer, tokens, topk_idx, topk_w) From 7c3b3179605b138cff7c5b3b755c71ebcdc27b0f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 17:45:14 -0700 Subject: 
[PATCH 05/18] EP PyTorch: drop ep_scope; ep_finalize is optional with atexit fallback Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/ep_moe.py | 14 ++++++---- transformer_engine/pytorch/ep.py | 46 +++++--------------------------- 2 files changed, 15 insertions(+), 45 deletions(-) diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index ca7469951d..24e271f2f8 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -16,9 +16,10 @@ from transformer_engine.pytorch.ep import ( EpBuffer, - ep_scope, - ep_dispatch, + ep_bootstrap, ep_combine, + ep_dispatch, + ep_finalize, ) @@ -114,15 +115,18 @@ def main(): recv_pr = ep_size * T * args.top_k ep_group = dist.new_group(ranks=list(range(world_size)), backend="nccl") - with ep_scope( + ep_bootstrap( ep_group, num_experts=num_experts, max_tokens_per_rank=T, recv_capacity_per_rank=recv_pr, hidden_dim=args.hidden, - ): + ) + try: _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device) - dist.destroy_process_group() + finally: + ep_finalize() + dist.destroy_process_group() def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device): diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 9e66d882b4..b3612401bd 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -6,8 +6,7 @@ from __future__ import annotations import atexit -from contextlib import contextmanager -from typing import Iterator, Optional +from typing import Optional import torch import torch.distributed as dist @@ -19,7 +18,6 @@ "EpBuffer", "ep_bootstrap", "ep_finalize", - "ep_scope", "ep_dispatch", "ep_combine", "symm_mem_alloc", @@ -154,12 +152,12 @@ def ep_bootstrap( def ep_finalize() -> None: - """Explicit EP teardown; optional and idempotent. 
An atexit handler covers - normal shutdown; call this only before ``dist.destroy_process_group()``, - since the borrowed NCCL comm is invalid once the PG is destroyed. + """Optional explicit EP teardown; idempotent. - Propagates errors from the C++ teardown; use ``_atexit_finalize`` for the - best-effort interpreter-shutdown path. + An atexit handler covers normal interpreter shutdown, so most users do not + need to call this. Call it explicitly only before + ``dist.destroy_process_group()``, since the borrowed NCCL comm becomes + invalid once the PG is destroyed. """ global _BOOTSTRAPPED if not _BOOTSTRAPPED: @@ -170,38 +168,6 @@ def ep_finalize() -> None: _BOOTSTRAPPED = False -@contextmanager -def ep_scope( - ep_group: dist.ProcessGroup, - num_experts: int, - max_tokens_per_rank: int, - recv_capacity_per_rank: int, - hidden_dim: int, - max_num_sms: int = 0, - zero_copy: bool = False, - max_token_dtype: torch.dtype = torch.bfloat16, -) -> Iterator[None]: - """Context manager: ``ep_bootstrap`` on enter, ``ep_finalize`` on exit. - - Use when you tear down the EP process group yourself, so the borrowed NCCL - comm is released before ``dist.destroy_process_group()``. 
- """ - ep_bootstrap( - ep_group, - num_experts=num_experts, - max_tokens_per_rank=max_tokens_per_rank, - recv_capacity_per_rank=recv_capacity_per_rank, - hidden_dim=hidden_dim, - max_num_sms=max_num_sms, - zero_copy=zero_copy, - max_token_dtype=max_token_dtype, - ) - try: - yield - finally: - ep_finalize() - - # Buffer From 603cc533af459bc7aa9f58e77eae13d7d4e7c946 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 17:57:59 -0700 Subject: [PATCH 06/18] EP PyTorch: restrict payload dtype to bf16; refresh stale ep.cpp comments; drop unused ep_group kwargs Signed-off-by: Phuong Nguyen --- examples/pytorch/ep/bench/ep_bench.py | 1 - tests/pytorch/distributed/run_ep.py | 1 - .../pytorch/csrc/extensions/ep.cpp | 11 ++++------ transformer_engine/pytorch/ep.py | 21 +++++++------------ 4 files changed, 12 insertions(+), 22 deletions(-) diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py index 973703b710..82a5429f9d 100644 --- a/examples/pytorch/ep/bench/ep_bench.py +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -182,7 +182,6 @@ def main(): recv_capacity_per_rank=recv_pr, hidden_dim=H, num_local_experts=num_local_experts, - ep_group=ep_group, ) tokens = tokens_hbm diff --git a/tests/pytorch/distributed/run_ep.py b/tests/pytorch/distributed/run_ep.py index 3b343466b9..0e29ca7eae 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -118,7 +118,6 @@ def _make_buffer(self, alignment=0, top_k=TOP_K): hidden_dim=HIDDEN_DIM, num_local_experts=NUM_LOCAL_EXPERTS, alignment=alignment, - ep_group=self.ep_group, ) def _make_raw_recv(self, dtype=torch.bfloat16): diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index 018b5addc0..46484b8b67 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -46,15 +46,12 @@ std::string g_ep_group_name; // NOLINT(runtime/string) 
// True while the EP backend holds a borrowed reference to torch's NCCL comm. bool g_ep_initialized = false; -// Zero-copy IO toggle. Placeholder for the symm-mem fast path; per-step ops -// always pass kNoWindow in this release regardless of the toggle. Wired up -// so the switch is a one-line change when the backend lands the fast path. -// Atomic so the Python-side toggle is safe against concurrent -// ep_dispatch/combine (which release the GIL). +// Zero-copy IO toggle captured at ep_initialize. Atomic so the Python-side +// toggle is safe against concurrent ep_dispatch/combine (which release the GIL). std::atomic g_zero_copy_enabled{false}; -// Per-step ops always pass kNoWindow in this release; the symm-mem IO path is -// planned for a near-future release. +// Sentinel returned by maybe_make_window when zero-copy is off or the tensor +// is not symm-mem-backed; the backend treats it as "no window, use staged copy". constexpr NVTECommWindow kNoWindow = {nullptr, 0}; // Resolve ``t`` to an NCCL symm-mem window for the zero-copy one-sided path. diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index b3612401bd..34cbe93a29 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -627,17 +627,12 @@ def backward(ctx, g_result): # type: ignore[override] # Public high-level wrappers -# FP8 dispatch is not yet supported by the common backend. -_FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2) - - -def _reject_fp8(*tensors: torch.Tensor) -> None: - for t in tensors: - if t.dtype in _FP8_DTYPES: - raise NotImplementedError( - f"FP8 dispatch/combine not supported (got dtype={t.dtype}); " - "quantize outside the EP boundary." - ) +# NCCL EP currently only supports bfloat16 payload tensors. +def _require_bf16(name: str, t: torch.Tensor) -> None: + if t.dtype is not torch.bfloat16: + raise NotImplementedError( + f"NCCL EP currently supports only bfloat16 payloads; got {name}.dtype={t.dtype}." 
+ ) def ep_dispatch( @@ -652,7 +647,7 @@ def ep_dispatch( buffer's persistent slots — consume them before the next ep_dispatch on the same buffer or they get overwritten. token_counts is non-differentiable. """ - _reject_fp8(tokens, buffer.recv_tokens) + _require_bf16("tokens", tokens) return _EpDispatch.apply( buffer.handle_mem, buffer.top_k, @@ -680,7 +675,7 @@ def ep_combine( Result shape is (num_local_tokens, buffer.hidden_dim); defaults to buffer.max_tokens_per_rank rows. """ - _reject_fp8(expert_out, buffer.combine_in) + _require_bf16("expert_out", expert_out) if num_local_tokens is None: num_local_tokens = buffer.max_tokens_per_rank return _EpCombine.apply( From 999ac378ec00358b066d70361fe8e1ba04d84246 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 18:26:42 -0700 Subject: [PATCH 07/18] EP PyTorch: clear pylint warnings in ep.py (broad-except suppression, _-prefixed stub args, autograd docstrings) Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 34cbe93a29..d5368e6e47 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -96,7 +96,7 @@ def _atexit_finalize() -> None: if _BOOTSTRAPPED: try: tex.ep_finalize() - except Exception: + except Exception: # pylint: disable=broad-exception-caught import traceback traceback.print_exc() @@ -378,7 +378,7 @@ def _prepare_op( @_prepare_op.register_fake -def _(*args, **kw): +def _(*_args, **_kw): return None @@ -399,7 +399,7 @@ def _dispatch_op( @_dispatch_op.register_fake -def _(*args, **kw): +def _(*_args, **_kw): return None @@ -417,7 +417,7 @@ def _combine_op( @_combine_op.register_fake -def _(*args, **kw): +def _(*_args, **_kw): return None @@ -437,7 +437,7 @@ def _dispatch_bwd_op( @_dispatch_bwd_op.register_fake -def _(*args, **kw): +def _(*_args, **_kw): return None @@ -455,7 +455,7 @@ def 
_combine_bwd_op( @_combine_bwd_op.register_fake -def _(*args, **kw): +def _(*_args, **_kw): return None @@ -516,6 +516,7 @@ def forward( # type: ignore[override] tokens: torch.Tensor, topk_weights: torch.Tensor, ): + """Prepare + dispatch; stashes buffer slots + shapes for the bwd pass.""" torch.ops.transformer_engine_ep.prepare( handle_mem, top_k, topk_idx, token_counts, alignment ) @@ -545,6 +546,7 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] + """Dispatch backward into the persistent grad_tokens/grad_topk_weights slots.""" device = ctx.handle_mem.device if g_recv_tokens is None: g_recv_tokens = torch.zeros( @@ -599,6 +601,7 @@ def forward( # type: ignore[override] hidden_dim: int, expert_out: torch.Tensor, ): + """Combine expert outputs; reuses combine_in as the grad slot for bwd.""" device = expert_out.device # Stage expert_out into the persistent combine_in slot (symm-mem-backed # in the zero-copy path); its storage is reused as grad_combine_in in bwd. 
@@ -611,6 +614,7 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_result): # type: ignore[override] + """Combine backward into combine_in storage; returned as grad of expert_out.""" grad_combine_in = ctx.combine_in if not g_result.is_contiguous(): g_result = g_result.contiguous() From dbc762449a244ec550311f354cdf19e2b6586e9f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jun 2026 01:29:03 +0000 Subject: [PATCH 08/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/pytorch/ep/bench/ep_bench.py | 4 +--- examples/pytorch/ep/ep_moe.py | 9 ++++----- transformer_engine/pytorch/csrc/extensions/ep.cpp | 4 ++-- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/examples/pytorch/ep/bench/ep_bench.py b/examples/pytorch/ep/bench/ep_bench.py index 82a5429f9d..81f5b83883 100644 --- a/examples/pytorch/ep/bench/ep_bench.py +++ b/examples/pytorch/ep/bench/ep_bench.py @@ -208,9 +208,7 @@ def main(): eo_p = recv_tokens.detach().clone().requires_grad_(True) # Stand-in callables; the cuda-graph branch below swaps in graphed versions. 
- fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w)[ # noqa: E731 - 0 - ] + fwd_bwd_dispatch_fn = lambda x: ep_dispatch(buffer, x, topk_idx, topk_w)[0] # noqa: E731 fwd_bwd_combine_fn = lambda eo: ep_combine(buffer, eo) # noqa: E731 def _dispatch_raw(): diff --git a/examples/pytorch/ep/ep_moe.py b/examples/pytorch/ep/ep_moe.py index 24e271f2f8..f72912301b 100644 --- a/examples/pytorch/ep/ep_moe.py +++ b/examples/pytorch/ep/ep_moe.py @@ -123,7 +123,9 @@ def main(): hidden_dim=args.hidden, ) try: - _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device) + _run_layer( + args, rank, world_size, ep_size, num_experts, num_local_experts, T, recv_pr, device + ) finally: ep_finalize() dist.destroy_process_group() @@ -196,10 +198,7 @@ def _run_layer(args, rank, world_size, ep_size, num_experts, num_local_experts, torch.cuda.synchronize() dt_ms = (time.perf_counter() - t0) * 1000.0 / args.benchmark_iters if rank == 0: - print( - f"[ep_moe --benchmark] HBM: {dt_ms:.3f} ms/iter " - f"(iters={args.benchmark_iters})" - ) + print(f"[ep_moe --benchmark] HBM: {dt_ms:.3f} ms/iter (iters={args.benchmark_iters})") if args.check: # All-gather inputs/outputs/grads for a global reference comparison. 
diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index 46484b8b67..8624605edd 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -102,8 +102,8 @@ bool ep_get_zero_copy() { return g_zero_copy_enabled.load(std::memory_order_rela void ep_initialize(uintptr_t comm_ptr, const std::string& group_name, int64_t num_experts, int64_t max_tokens_per_rank, int64_t max_recv_tokens_per_rank, - int64_t hidden_dim, int64_t max_num_sms, - pybind11::object max_token_dtype, bool zero_copy) { + int64_t hidden_dim, int64_t max_num_sms, pybind11::object max_token_dtype, + bool zero_copy) { NVTE_CHECK(!group_name.empty(), "group_name must be non-empty (used for symm-mem lookup)"); NVTE_CHECK(comm_ptr != 0, "comm_ptr must be non-null (torch NCCL host comm pointer)"); NVTE_CHECK(!g_ep_initialized, "ep_initialize called twice without ep_finalize"); From 4742ca8c569ac6dce0ef17da4ea78f2e2fb88d8d Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 20:28:12 -0700 Subject: [PATCH 09/18] EP PyTorch: skip combine_in staging copy in non-zero-copy mode Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index d5368e6e47..64a9a69df7 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -205,6 +205,7 @@ class EpBuffer: "token_counts", "grad_tokens", "grad_topk_weights", + "zero_copy", ) def __init__( @@ -239,6 +240,7 @@ def __init__( recv_shape = (self.recv_capacity_per_rank, self.hidden_dim) send_shape = (self.max_tokens_per_rank, self.hidden_dim) zero_copy = bool(tex.ep_get_zero_copy()) + self.zero_copy = zero_copy if zero_copy: if ep_group is None: raise ValueError("EpBuffer requires ep_group when ep_bootstrap(zero_copy=True).") @@ -314,6 +316,7 
@@ def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torc inst.num_local_experts = int(num_local_experts) inst.payload_dtype = payload_dtype inst.device = device + inst.zero_copy = bool(tex.ep_get_zero_copy()) size_bytes = tex.ep_handle_mem_size(inst.top_k, inst.alignment) inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) @@ -599,17 +602,23 @@ def forward( # type: ignore[override] combine_in: torch.Tensor, num_local_tokens: int, hidden_dim: int, + zero_copy: bool, expert_out: torch.Tensor, ): - """Combine expert outputs; reuses combine_in as the grad slot for bwd.""" + """Combine expert outputs; combine_in storage is reused as the grad slot in bwd.""" device = expert_out.device - # Stage expert_out into the persistent combine_in slot (symm-mem-backed - # in the zero-copy path); its storage is reused as grad_combine_in in bwd. - combine_in.copy_(expert_out) + # Zero-copy mode: peers read from combine_in via symm-mem, so it must + # hold the expert outputs; stage expert_out into it unless aliased. + # Otherwise the kernel reads expert_out directly, no copy needed. 
+ if zero_copy and combine_in.data_ptr() != expert_out.data_ptr(): + combine_in.copy_(expert_out) + kernel_in = combine_in + else: + kernel_in = expert_out result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) - torch.ops.transformer_engine_ep.combine(handle_mem, combine_in, result) + torch.ops.transformer_engine_ep.combine(handle_mem, kernel_in, result) ctx.handle_mem = handle_mem - ctx.combine_in = combine_in # reused as grad_combine_in in bwd + ctx.combine_in = combine_in # used as grad_combine_in in bwd return result @staticmethod @@ -624,6 +633,7 @@ def backward(ctx, g_result): # type: ignore[override] None, # combine_in None, # num_local_tokens None, # hidden_dim + None, # zero_copy grad_combine_in, ) @@ -687,5 +697,6 @@ def ep_combine( buffer.combine_in, num_local_tokens, buffer.hidden_dim, + buffer.zero_copy, expert_out, ) From 44e0a510f5f4bf9b8ce37e3a27cc9d10f12b9215 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 20:46:55 -0700 Subject: [PATCH 10/18] EP PyTorch: gate symm-mem slot allocation on zero-copy; require expert_out to alias combine_in in zero-copy Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 151 ++++++++++++++++++++----------- 1 file changed, 99 insertions(+), 52 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 64a9a69df7..264d4e51f9 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -241,6 +241,8 @@ def __init__( send_shape = (self.max_tokens_per_rank, self.hidden_dim) zero_copy = bool(tex.ep_get_zero_copy()) self.zero_copy = zero_copy + # Cross-rank slots are pre-allocated as symm-mem only when zero-copy + # is on; non-zero-copy mode allocates plain HBM per call in the ops. 
if zero_copy: if ep_group is None: raise ValueError("EpBuffer requires ep_group when ep_bootstrap(zero_copy=True).") @@ -251,12 +253,10 @@ def __init__( ) self.grad_tokens = symm_mem_alloc(send_shape, payload_dtype, ep_group, device=device) else: - self.recv_tokens = torch.empty(recv_shape, dtype=payload_dtype, device=device) - self.combine_in = torch.empty(recv_shape, dtype=payload_dtype, device=device) - self.recv_topk_weights = torch.empty( - self.recv_capacity_per_rank, dtype=torch.float32, device=device - ) - self.grad_tokens = torch.empty(send_shape, dtype=payload_dtype, device=device) + self.recv_tokens = None + self.combine_in = None + self.recv_topk_weights = None + self.grad_tokens = None # Per-rank scratch; never cross-rank, plain HBM regardless of mode. self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) self.grad_topk_weights = torch.empty( @@ -272,8 +272,8 @@ def from_external( hidden_dim: int, num_local_experts: int, *, - recv_tokens: torch.Tensor, - combine_in: torch.Tensor, + recv_tokens: Optional[torch.Tensor] = None, + combine_in: Optional[torch.Tensor] = None, recv_topk_weights: Optional[torch.Tensor] = None, grad_tokens: Optional[torch.Tensor] = None, token_counts: Optional[torch.Tensor] = None, @@ -284,10 +284,10 @@ def from_external( ) -> "EpBuffer": """Construct from caller-allocated payload buffers. - Useful for sharing a pre-allocated pool across layers/microbatches and - for plugging in symm-mem-backed tensors. Caller-supplied slots are - validated against the expected shape and dtype; ``None`` slots get a - fresh HBM allocation. ``handle_mem`` is always allocated fresh. + In zero-copy mode recv_tokens, combine_in, recv_topk_weights, and + grad_tokens must be supplied and symm-mem-backed; in non-zero-copy + mode they default to None and the ops allocate per call. handle_mem + is always allocated fresh. 
""" if device is None: device = torch.device("cuda", torch.cuda.current_device()) @@ -321,18 +321,28 @@ def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torc size_bytes = tex.ep_handle_mem_size(inst.top_k, inst.alignment) inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) - inst.recv_tokens = _check(recv_tokens, "recv_tokens", recv_shape, payload_dtype) - inst.combine_in = _check(combine_in, "combine_in", recv_shape, payload_dtype) - inst.recv_topk_weights = ( - _check(recv_topk_weights, "recv_topk_weights", recv_w_shape, torch.float32) - if recv_topk_weights is not None - else torch.empty(recv_w_shape, dtype=torch.float32, device=device) - ) - inst.grad_tokens = ( - _check(grad_tokens, "grad_tokens", send_shape, payload_dtype) - if grad_tokens is not None - else torch.empty(send_shape, dtype=payload_dtype, device=device) - ) + if inst.zero_copy: + if ( + recv_tokens is None + or combine_in is None + or recv_topk_weights is None + or grad_tokens is None + ): + raise ValueError( + "EpBuffer.from_external requires recv_tokens, combine_in, recv_topk_weights, " + "and grad_tokens (all symm-mem-backed) when zero-copy is enabled." 
+ ) + inst.recv_tokens = _check(recv_tokens, "recv_tokens", recv_shape, payload_dtype) + inst.combine_in = _check(combine_in, "combine_in", recv_shape, payload_dtype) + inst.recv_topk_weights = _check( + recv_topk_weights, "recv_topk_weights", recv_w_shape, torch.float32 + ) + inst.grad_tokens = _check(grad_tokens, "grad_tokens", send_shape, payload_dtype) + else: + inst.recv_tokens = None + inst.combine_in = None + inst.recv_topk_weights = None + inst.grad_tokens = None inst.token_counts = ( _check(token_counts, "token_counts", counts_shape, torch.int32) if token_counts is not None @@ -357,7 +367,8 @@ def record_stream(self, stream: torch.cuda.Stream) -> None: self.token_counts, self.grad_topk_weights, ): - t.record_stream(stream) + if t is not None: + t.record_stream(stream) # torch.library custom ops (so they don't graph-break under torch.compile) @@ -549,7 +560,8 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] - """Dispatch backward into the persistent grad_tokens/grad_topk_weights slots.""" + """Dispatch backward; grad_tokens uses the buffer's symm-mem slot in + zero-copy mode or a fresh HBM tensor otherwise.""" device = ctx.handle_mem.device if g_recv_tokens is None: g_recv_tokens = torch.zeros( @@ -561,8 +573,13 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: g_recv_tokens = g_recv_tokens.contiguous() if not g_recv_topk_weights.is_contiguous(): g_recv_topk_weights = g_recv_topk_weights.contiguous() - # Narrow the persistent slots to this call's flattened leading dim. - grad_tokens = ctx.grad_tokens_buf.narrow(0, 0, ctx.tokens_T_flat) + if ctx.grad_tokens_buf is not None: + # Zero-copy: narrow the persistent symm-mem slot to this call's leading dim. 
+ grad_tokens = ctx.grad_tokens_buf.narrow(0, 0, ctx.tokens_T_flat) + else: + grad_tokens = torch.empty( + ctx.tokens_T_flat, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device + ) grad_topk_weights = ctx.grad_topk_weights_buf.narrow(0, 0, ctx.topk_T_flat) torch.ops.transformer_engine_ep.dispatch_bwd( ctx.handle_mem, @@ -590,43 +607,59 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: class _EpCombine(torch.autograd.Function): - """Autograd-aware combine. combine_in is reused as grad_combine_in in bwd; - fwd/bwd share handle_mem so don't re-run ep_prepare between them. Caller - must pre-apply the topk weighting to expert_out. + """Autograd-aware combine. Zero-copy mode requires expert_out to alias + buffer.combine_in (no implicit staging), and that storage is reused as + the grad slot in bwd. Non-zero-copy mode reads expert_out directly and + allocates the bwd grad slot fresh. Caller pre-applies topk weighting. """ @staticmethod def forward( # type: ignore[override] ctx, handle_mem: torch.Tensor, - combine_in: torch.Tensor, + combine_in: Optional[torch.Tensor], num_local_tokens: int, hidden_dim: int, zero_copy: bool, expert_out: torch.Tensor, ): - """Combine expert outputs; combine_in storage is reused as the grad slot in bwd.""" + """Combine fwd; zero-copy requires expert_out to alias combine_in.""" + if zero_copy: + if combine_in is None: + raise RuntimeError( + "ep_combine: zero-copy mode requires buffer.combine_in to be allocated." + ) + if combine_in.data_ptr() != expert_out.data_ptr(): + raise RuntimeError( + "ep_combine: zero-copy mode requires expert_out to alias " + "buffer.combine_in (write expert outputs directly into that slot; " + "no implicit copy)." + ) device = expert_out.device - # Zero-copy mode: peers read from combine_in via symm-mem, so it must - # hold the expert outputs; stage expert_out into it unless aliased. - # Otherwise the kernel reads expert_out directly, no copy needed. 
- if zero_copy and combine_in.data_ptr() != expert_out.data_ptr(): - combine_in.copy_(expert_out) - kernel_in = combine_in - else: - kernel_in = expert_out result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) - torch.ops.transformer_engine_ep.combine(handle_mem, kernel_in, result) + torch.ops.transformer_engine_ep.combine(handle_mem, expert_out, result) ctx.handle_mem = handle_mem - ctx.combine_in = combine_in # used as grad_combine_in in bwd + ctx.combine_in = combine_in # None in non-zero-copy; reused as grad slot otherwise + ctx.zero_copy = zero_copy + ctx.recv_capacity = expert_out.shape[0] + ctx.hidden_dim = expert_out.shape[-1] + ctx.expert_out_dtype = expert_out.dtype return result @staticmethod def backward(ctx, g_result): # type: ignore[override] - """Combine backward into combine_in storage; returned as grad of expert_out.""" - grad_combine_in = ctx.combine_in + """Combine bwd; writes into combine_in in zero-copy or a fresh slot otherwise.""" if not g_result.is_contiguous(): g_result = g_result.contiguous() + if ctx.zero_copy: + grad_combine_in = ctx.combine_in + else: + grad_combine_in = torch.empty( + ctx.recv_capacity, + ctx.hidden_dim, + dtype=ctx.expert_out_dtype, + device=ctx.handle_mem.device, + ) torch.ops.transformer_engine_ep.combine_bwd(ctx.handle_mem, g_result, grad_combine_in) return ( None, # handle_mem @@ -657,17 +690,30 @@ def ep_dispatch( ): """Run prepare + dispatch with autograd. topk_idx must be int64. - Returns (recv_tokens, recv_topk_weights, token_counts); views into the - buffer's persistent slots — consume them before the next ep_dispatch on - the same buffer or they get overwritten. token_counts is non-differentiable. + Returns (recv_tokens, recv_topk_weights, token_counts). In zero-copy mode + recv_tokens / recv_topk_weights alias the buffer's persistent symm-mem + slots; otherwise they are freshly allocated. token_counts is non-diff. 
""" _require_bf16("tokens", tokens) + if buffer.zero_copy: + recv_tokens = buffer.recv_tokens + recv_topk_weights = buffer.recv_topk_weights + else: + recv_tokens = torch.empty( + buffer.recv_capacity_per_rank, + buffer.hidden_dim, + dtype=buffer.payload_dtype, + device=buffer.device, + ) + recv_topk_weights = torch.empty( + buffer.recv_capacity_per_rank, dtype=torch.float32, device=buffer.device + ) return _EpDispatch.apply( buffer.handle_mem, buffer.top_k, buffer.alignment, - buffer.recv_tokens, - buffer.recv_topk_weights, + recv_tokens, + recv_topk_weights, buffer.token_counts, buffer.grad_tokens, buffer.grad_topk_weights, @@ -683,8 +729,9 @@ def ep_combine( *, num_local_tokens: Optional[int] = None, ): - """Combine expert outputs back to the source rank, with autograd. The - caller must pre-apply the topk weighting to expert_out. + """Combine expert outputs back to the source rank, with autograd. Caller + pre-applies topk weighting. Zero-copy mode requires expert_out to alias + buffer.combine_in (write expert outputs into that slot directly). Result shape is (num_local_tokens, buffer.hidden_dim); defaults to buffer.max_tokens_per_rank rows. From e6ec4b65893aba3296f8d4652dfc9db85ccedb2f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 21:39:55 -0700 Subject: [PATCH 11/18] EP PyTorch: rename buffer slots to dispatch_/combine_symm_buf; drop grad_tokens/grad_topk_weights; alias-check bwd grads in zero-copy Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 211 ++++++++++++++++--------------- 1 file changed, 108 insertions(+), 103 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 264d4e51f9..ec99e1250a 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -198,13 +198,15 @@ class EpBuffer: "num_local_experts", "payload_dtype", "device", - # payload slots - "recv_tokens", - "combine_in", - "recv_topk_weights", + # Symm-mem slots (zero-copy only). 
Each is reused across fwd and bwd: + # dispatch_symm_buf: fwd out (recv_tokens) / bwd in (g_recv_tokens) + # dispatch_w_symm_buf: fwd out (recv_topk_w) / bwd in (g_recv_topk_w) + # combine_symm_buf: fwd in (expert_out) / bwd out (g_expert_out) + "dispatch_symm_buf", + "dispatch_w_symm_buf", + "combine_symm_buf", + # Per-rank scratch (always HBM). "token_counts", - "grad_tokens", - "grad_topk_weights", "zero_copy", ) @@ -238,30 +240,26 @@ def __init__( self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) recv_shape = (self.recv_capacity_per_rank, self.hidden_dim) - send_shape = (self.max_tokens_per_rank, self.hidden_dim) zero_copy = bool(tex.ep_get_zero_copy()) self.zero_copy = zero_copy - # Cross-rank slots are pre-allocated as symm-mem only when zero-copy - # is on; non-zero-copy mode allocates plain HBM per call in the ops. if zero_copy: if ep_group is None: raise ValueError("EpBuffer requires ep_group when ep_bootstrap(zero_copy=True).") - self.recv_tokens = symm_mem_alloc(recv_shape, payload_dtype, ep_group, device=device) - self.combine_in = symm_mem_alloc(recv_shape, payload_dtype, ep_group, device=device) - self.recv_topk_weights = symm_mem_alloc( + self.dispatch_symm_buf = symm_mem_alloc( + recv_shape, payload_dtype, ep_group, device=device + ) + self.dispatch_w_symm_buf = symm_mem_alloc( (self.recv_capacity_per_rank,), torch.float32, ep_group, device=device ) - self.grad_tokens = symm_mem_alloc(send_shape, payload_dtype, ep_group, device=device) + self.combine_symm_buf = symm_mem_alloc( + recv_shape, payload_dtype, ep_group, device=device + ) else: - self.recv_tokens = None - self.combine_in = None - self.recv_topk_weights = None - self.grad_tokens = None - # Per-rank scratch; never cross-rank, plain HBM regardless of mode. + self.dispatch_symm_buf = None + self.dispatch_w_symm_buf = None + self.combine_symm_buf = None + # token_counts is local-only routing scratch; always plain HBM. 
self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) - self.grad_topk_weights = torch.empty( - (self.max_tokens_per_rank, self.top_k), dtype=torch.float32, device=device - ) @classmethod def from_external( @@ -272,22 +270,20 @@ def from_external( hidden_dim: int, num_local_experts: int, *, - recv_tokens: Optional[torch.Tensor] = None, - combine_in: Optional[torch.Tensor] = None, - recv_topk_weights: Optional[torch.Tensor] = None, - grad_tokens: Optional[torch.Tensor] = None, + dispatch_symm_buf: Optional[torch.Tensor] = None, + dispatch_w_symm_buf: Optional[torch.Tensor] = None, + combine_symm_buf: Optional[torch.Tensor] = None, token_counts: Optional[torch.Tensor] = None, - grad_topk_weights: Optional[torch.Tensor] = None, alignment: int = 0, payload_dtype: torch.dtype = torch.bfloat16, device: Optional[torch.device] = None, ) -> "EpBuffer": - """Construct from caller-allocated payload buffers. + """Construct from caller-allocated buffers. - In zero-copy mode recv_tokens, combine_in, recv_topk_weights, and - grad_tokens must be supplied and symm-mem-backed; in non-zero-copy - mode they default to None and the ops allocate per call. handle_mem - is always allocated fresh. + In zero-copy mode dispatch_symm_buf, dispatch_w_symm_buf, and + combine_symm_buf must all be supplied and symm-mem-backed; in + non-zero-copy mode they must all be None (ops allocate per call). + handle_mem is always allocated fresh. 
""" if device is None: device = torch.device("cuda", torch.cuda.current_device()) @@ -295,8 +291,6 @@ def from_external( if alignment > 1 and (alignment & (alignment - 1)) != 0: raise ValueError(f"alignment must be 0, 1, or a power of two (got {alignment}).") recv_shape = (recv_capacity_per_rank, hidden_dim) - send_shape = (max_tokens_per_rank, hidden_dim) - topk_shape = (max_tokens_per_rank, top_k) recv_w_shape = (recv_capacity_per_rank,) counts_shape = (num_local_experts,) @@ -323,36 +317,41 @@ def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torc if inst.zero_copy: if ( - recv_tokens is None - or combine_in is None - or recv_topk_weights is None - or grad_tokens is None + dispatch_symm_buf is None + or dispatch_w_symm_buf is None + or combine_symm_buf is None ): raise ValueError( - "EpBuffer.from_external requires recv_tokens, combine_in, recv_topk_weights, " - "and grad_tokens (all symm-mem-backed) when zero-copy is enabled." + "EpBuffer.from_external: zero-copy mode requires dispatch_symm_buf, " + "dispatch_w_symm_buf, and combine_symm_buf (all symm-mem-backed)." 
) - inst.recv_tokens = _check(recv_tokens, "recv_tokens", recv_shape, payload_dtype) - inst.combine_in = _check(combine_in, "combine_in", recv_shape, payload_dtype) - inst.recv_topk_weights = _check( - recv_topk_weights, "recv_topk_weights", recv_w_shape, torch.float32 + inst.dispatch_symm_buf = _check( + dispatch_symm_buf, "dispatch_symm_buf", recv_shape, payload_dtype + ) + inst.dispatch_w_symm_buf = _check( + dispatch_w_symm_buf, "dispatch_w_symm_buf", recv_w_shape, torch.float32 + ) + inst.combine_symm_buf = _check( + combine_symm_buf, "combine_symm_buf", recv_shape, payload_dtype ) - inst.grad_tokens = _check(grad_tokens, "grad_tokens", send_shape, payload_dtype) else: - inst.recv_tokens = None - inst.combine_in = None - inst.recv_topk_weights = None - inst.grad_tokens = None + if ( + dispatch_symm_buf is not None + or dispatch_w_symm_buf is not None + or combine_symm_buf is not None + ): + raise ValueError( + "EpBuffer.from_external: dispatch_symm_buf / dispatch_w_symm_buf / " + "combine_symm_buf are only used in zero-copy mode." 
+ ) + inst.dispatch_symm_buf = None + inst.dispatch_w_symm_buf = None + inst.combine_symm_buf = None inst.token_counts = ( _check(token_counts, "token_counts", counts_shape, torch.int32) if token_counts is not None else torch.empty(counts_shape, dtype=torch.int32, device=device) ) - inst.grad_topk_weights = ( - _check(grad_topk_weights, "grad_topk_weights", topk_shape, torch.float32) - if grad_topk_weights is not None - else torch.empty(topk_shape, dtype=torch.float32, device=device) - ) return inst def record_stream(self, stream: torch.cuda.Stream) -> None: @@ -360,12 +359,10 @@ def record_stream(self, stream: torch.cuda.Stream) -> None: defers reclaim until stream has caught up.""" for t in ( self.handle_mem, - self.recv_tokens, - self.combine_in, - self.recv_topk_weights, - self.grad_tokens, + self.dispatch_symm_buf, + self.dispatch_w_symm_buf, + self.combine_symm_buf, self.token_counts, - self.grad_topk_weights, ): if t is not None: t.record_stream(stream) @@ -510,9 +507,10 @@ def _ep_combine_raw(buffer: "EpBuffer", expert_out: torch.Tensor, result: torch. class _EpDispatch(torch.autograd.Function): - """Autograd-aware prepare + dispatch. Fwd/bwd share handle_mem and the - EpBuffer slots; do not re-run ep_prepare between them and do not share - EpBuffer with another in-flight call (see EpBuffer). + """Autograd-aware prepare + dispatch. Fwd produces recv_tokens (alias of + dispatch_symm_buf in zero-copy, fresh otherwise). Zero-copy bwd requires + the incoming grads to alias dispatch_symm_buf / dispatch_w_symm_buf + (no implicit staging). Fwd/bwd share handle_mem; do not re-run ep_prepare. 
""" @staticmethod @@ -521,16 +519,15 @@ def forward( # type: ignore[override] handle_mem: torch.Tensor, top_k: int, alignment: int, + zero_copy: bool, recv_tokens: torch.Tensor, recv_topk_weights: torch.Tensor, token_counts: torch.Tensor, - grad_tokens_buf: torch.Tensor, - grad_topk_weights_buf: torch.Tensor, topk_idx: torch.Tensor, tokens: torch.Tensor, topk_weights: torch.Tensor, ): - """Prepare + dispatch; stashes buffer slots + shapes for the bwd pass.""" + """Prepare + dispatch; saves shapes for the bwd pass.""" torch.ops.transformer_engine_ep.prepare( handle_mem, top_k, topk_idx, token_counts, alignment ) @@ -543,14 +540,18 @@ def forward( # type: ignore[override] recv_topk_weights, ) ctx.handle_mem = handle_mem - ctx.grad_tokens_buf = grad_tokens_buf - ctx.grad_topk_weights_buf = grad_topk_weights_buf + ctx.zero_copy = zero_copy + # Stash the symm-mem slot pointers so bwd can enforce alias of the + # grad inputs. In non-zero-copy mode the slots are fresh per call; + # no enforcement is meaningful, so leave the pointers as None. 
+ ctx.dispatch_symm_ptr = recv_tokens.data_ptr() if zero_copy else None + ctx.dispatch_w_symm_ptr = recv_topk_weights.data_ptr() if zero_copy else None ctx.tokens_shape = tokens.shape ctx.tokens_dtype = tokens.dtype ctx.topk_weights_shape = topk_weights.shape - ctx.topk_weights_dtype = topk_weights.dtype ctx.tokens_T_flat = tokens.numel() // tokens.shape[-1] ctx.topk_T_flat = topk_weights.numel() // topk_weights.shape[-1] + ctx.top_k = topk_weights.shape[-1] ctx.recv_capacity = recv_tokens.shape[0] ctx.hidden_dim = tokens.shape[-1] ctx.mark_non_differentiable(token_counts) @@ -560,8 +561,7 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] - """Dispatch backward; grad_tokens uses the buffer's symm-mem slot in - zero-copy mode or a fresh HBM tensor otherwise.""" + """Dispatch bwd; in zero-copy the grad inputs must alias the symm-mem slots.""" device = ctx.handle_mem.device if g_recv_tokens is None: g_recv_tokens = torch.zeros( @@ -573,14 +573,24 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: g_recv_tokens = g_recv_tokens.contiguous() if not g_recv_topk_weights.is_contiguous(): g_recv_topk_weights = g_recv_topk_weights.contiguous() - if ctx.grad_tokens_buf is not None: - # Zero-copy: narrow the persistent symm-mem slot to this call's leading dim. - grad_tokens = ctx.grad_tokens_buf.narrow(0, 0, ctx.tokens_T_flat) - else: - grad_tokens = torch.empty( - ctx.tokens_T_flat, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device - ) - grad_topk_weights = ctx.grad_topk_weights_buf.narrow(0, 0, ctx.topk_T_flat) + if ctx.zero_copy: + if g_recv_tokens.data_ptr() != ctx.dispatch_symm_ptr: + raise RuntimeError( + "ep_dispatch bwd: zero-copy mode requires g_recv_tokens to alias " + "buffer.dispatch_symm_buf (write MLP_bwd's grad into that slot; " + "no implicit copy)." 
+ ) + if g_recv_topk_weights.data_ptr() != ctx.dispatch_w_symm_ptr: + raise RuntimeError( + "ep_dispatch bwd: zero-copy mode requires g_recv_topk_weights to alias " + "buffer.dispatch_w_symm_buf (no implicit copy)." + ) + grad_tokens = torch.empty( + ctx.tokens_T_flat, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device + ) + grad_topk_weights = torch.empty( + ctx.topk_T_flat, ctx.top_k, dtype=torch.float32, device=device + ) torch.ops.transformer_engine_ep.dispatch_bwd( ctx.handle_mem, g_recv_tokens, @@ -588,58 +598,54 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: grad_tokens, grad_topk_weights, ) - # Reshape back to the original input shape so autograd's grad slot matches. - grad_tokens_out = grad_tokens.view(ctx.tokens_shape) - grad_topk_weights_out = grad_topk_weights.view(ctx.topk_weights_shape) return ( None, # handle_mem None, # top_k None, # alignment + None, # zero_copy None, # recv_tokens None, # recv_topk_weights None, # token_counts - None, # grad_tokens_buf - None, # grad_topk_weights_buf None, # topk_idx - grad_tokens_out, - grad_topk_weights_out, + grad_tokens.view(ctx.tokens_shape), + grad_topk_weights.view(ctx.topk_weights_shape), ) class _EpCombine(torch.autograd.Function): """Autograd-aware combine. Zero-copy mode requires expert_out to alias - buffer.combine_in (no implicit staging), and that storage is reused as - the grad slot in bwd. Non-zero-copy mode reads expert_out directly and - allocates the bwd grad slot fresh. Caller pre-applies topk weighting. + combine_symm_buf (no implicit staging), and that storage is reused as the + bwd grad slot. Non-zero-copy mode reads expert_out directly and allocates + the bwd grad slot fresh. Caller pre-applies topk weighting. 
""" @staticmethod def forward( # type: ignore[override] ctx, handle_mem: torch.Tensor, - combine_in: Optional[torch.Tensor], + combine_symm_buf: Optional[torch.Tensor], num_local_tokens: int, hidden_dim: int, zero_copy: bool, expert_out: torch.Tensor, ): - """Combine fwd; zero-copy requires expert_out to alias combine_in.""" + """Combine fwd; zero-copy requires expert_out to alias combine_symm_buf.""" if zero_copy: - if combine_in is None: + if combine_symm_buf is None: raise RuntimeError( - "ep_combine: zero-copy mode requires buffer.combine_in to be allocated." + "ep_combine: zero-copy mode requires buffer.combine_symm_buf to be allocated." ) - if combine_in.data_ptr() != expert_out.data_ptr(): + if combine_symm_buf.data_ptr() != expert_out.data_ptr(): raise RuntimeError( "ep_combine: zero-copy mode requires expert_out to alias " - "buffer.combine_in (write expert outputs directly into that slot; " + "buffer.combine_symm_buf (write expert outputs directly into that slot; " "no implicit copy)." 
) device = expert_out.device result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) torch.ops.transformer_engine_ep.combine(handle_mem, expert_out, result) ctx.handle_mem = handle_mem - ctx.combine_in = combine_in # None in non-zero-copy; reused as grad slot otherwise + ctx.combine_symm_buf = combine_symm_buf # reused as grad slot in zero-copy ctx.zero_copy = zero_copy ctx.recv_capacity = expert_out.shape[0] ctx.hidden_dim = expert_out.shape[-1] @@ -648,11 +654,11 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_result): # type: ignore[override] - """Combine bwd; writes into combine_in in zero-copy or a fresh slot otherwise.""" + """Combine bwd; writes into combine_symm_buf in zero-copy or a fresh slot otherwise.""" if not g_result.is_contiguous(): g_result = g_result.contiguous() if ctx.zero_copy: - grad_combine_in = ctx.combine_in + grad_combine_in = ctx.combine_symm_buf else: grad_combine_in = torch.empty( ctx.recv_capacity, @@ -663,7 +669,7 @@ def backward(ctx, g_result): # type: ignore[override] torch.ops.transformer_engine_ep.combine_bwd(ctx.handle_mem, g_result, grad_combine_in) return ( None, # handle_mem - None, # combine_in + None, # combine_symm_buf None, # num_local_tokens None, # hidden_dim None, # zero_copy @@ -696,8 +702,8 @@ def ep_dispatch( """ _require_bf16("tokens", tokens) if buffer.zero_copy: - recv_tokens = buffer.recv_tokens - recv_topk_weights = buffer.recv_topk_weights + recv_tokens = buffer.dispatch_symm_buf + recv_topk_weights = buffer.dispatch_w_symm_buf else: recv_tokens = torch.empty( buffer.recv_capacity_per_rank, @@ -712,11 +718,10 @@ def ep_dispatch( buffer.handle_mem, buffer.top_k, buffer.alignment, + buffer.zero_copy, recv_tokens, recv_topk_weights, buffer.token_counts, - buffer.grad_tokens, - buffer.grad_topk_weights, topk_idx, tokens, topk_weights, @@ -731,7 +736,7 @@ def ep_combine( ): """Combine expert outputs back to the source rank, with autograd. 
Caller pre-applies topk weighting. Zero-copy mode requires expert_out to alias - buffer.combine_in (write expert outputs into that slot directly). + buffer.combine_symm_buf (write expert outputs into that slot directly). Result shape is (num_local_tokens, buffer.hidden_dim); defaults to buffer.max_tokens_per_rank rows. @@ -741,7 +746,7 @@ def ep_combine( num_local_tokens = buffer.max_tokens_per_rank return _EpCombine.apply( buffer.handle_mem, - buffer.combine_in, + buffer.combine_symm_buf, num_local_tokens, buffer.hidden_dim, buffer.zero_copy, From 6c7c7b3652a498c009a48468d60a9d3900c31b8c Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 21:41:33 -0700 Subject: [PATCH 12/18] EP PyTorch: warn that ep_bootstrap(zero_copy=True) is experimental Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/ep.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index ec99e1250a..6fc69df92b 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -6,6 +6,7 @@ from __future__ import annotations import atexit +import warnings from typing import Optional import torch @@ -129,6 +130,12 @@ def ep_bootstrap( if ep_group.size() < 2: raise ValueError(f"ep_bootstrap requires ep_group.size() >= 2 (got {ep_group.size()}).") _check_nccl_runtime_version() + if zero_copy: + warnings.warn( + "ep_bootstrap(zero_copy=True) is experimental; the symm-mem IO path " + "and its alias contracts on EpBuffer slots are subject to change.", + stacklevel=2, + ) # Materialize the PG's NCCL comm before borrowing its raw handle. 
dist.barrier(group=ep_group, device_ids=[torch.cuda.current_device()]) From b3fb50d62eaf1d4b5190f9c943fd8fe4578fefd2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Jun 2026 00:23:34 +0000 Subject: [PATCH 13/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ep.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 6fc69df92b..9789b587d1 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -323,11 +323,7 @@ def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torc inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) if inst.zero_copy: - if ( - dispatch_symm_buf is None - or dispatch_w_symm_buf is None - or combine_symm_buf is None - ): + if dispatch_symm_buf is None or dispatch_w_symm_buf is None or combine_symm_buf is None: raise ValueError( "EpBuffer.from_external: zero-copy mode requires dispatch_symm_buf, " "dispatch_w_symm_buf, and combine_symm_buf (all symm-mem-backed)." 
From 50dcbdd0997300136a0e676d0d2e85d53cdcf94a Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 10 Jun 2026 17:30:04 -0700 Subject: [PATCH 14/18] EP PyTorch: wire test_ep.py into L1 distributed QA suite Signed-off-by: Phuong Nguyen --- qa/L1_pytorch_distributed_unittest/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 7eb34a62e4..50a51353d1 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -50,6 +50,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_ep.xml $TE_PATH/tests/pytorch/distributed/test_ep.py || test_fail "test_ep.py" # debug tests From 61b1b3a3172c1e1b738284c1eba6486f089c5515 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 10 Jun 2026 17:36:19 -0700 Subject: [PATCH 15/18] EP PyTorch: validate contiguity of dispatch/combine inputs and topk_weights fp32 dtype Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/csrc/extensions/ep.cpp | 3 +++ transformer_engine/pytorch/ep.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index 8624605edd..e340b41acf 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -183,6 +183,8 @@ void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, NVTE_CHECK(topk_weights.dim() >= 2, "topk_weights must be at least 2D [..., top_k]"); NVTE_CHECK(recv_tokens.dim() >= 2, "recv_tokens must be at least 2D [..., recv_pr, H]"); check_topk_idx_int64(topk_idx); + NVTE_CHECK(tokens.is_contiguous(), "tokens must be contiguous"); + NVTE_CHECK(topk_weights.is_contiguous(), "topk_weights must be contiguous"); const size_t H = static_cast(tokens.size(-1)); const size_t T_flat = tokens.numel() / H; @@ -225,6 +227,7 @@ void ep_combine(at::Tensor handle_mem, at::Tensor expert_out, at::Tensor result) auto stream = at::cuda::getCurrentCUDAStream().stream(); NVTE_CHECK(expert_out.dim() >= 2, "expert_out must be at least 2D [..., recv_pr, H]"); NVTE_CHECK(result.dim() >= 2, "result must be at least 2D [..., H]"); + NVTE_CHECK(expert_out.is_contiguous(), "expert_out must be contiguous"); const size_t H = static_cast(expert_out.size(-1)); const size_t recv_pr = expert_out.numel() / H; diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 9789b587d1..89951bc67c 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -704,6 +704,11 @@ def ep_dispatch( slots; otherwise they are freshly allocated. token_counts is non-diff. """ _require_bf16("tokens", tokens) + if topk_weights.dtype is not torch.float32: + raise TypeError( + f"topk_weights must be float32; got dtype={topk_weights.dtype}. " + "Cast with topk_weights.float() before calling." 
+ ) if buffer.zero_copy: recv_tokens = buffer.dispatch_symm_buf recv_topk_weights = buffer.dispatch_w_symm_buf From 6fcbee5899b20d1c8313508eaf1819a65927b09f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 11 Jun 2026 09:15:45 -0700 Subject: [PATCH 16/18] EP PyTorch: check topk_idx/topk_weights token count matches tokens in ep_dispatch Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/csrc/extensions/ep.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index e340b41acf..a7c15a1140 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -193,6 +193,10 @@ void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, NVTE_CHECK(static_cast(topk_weights.size(-1)) == topk_n, "topk_weights last dim must equal topk_idx last dim"); + NVTE_CHECK(static_cast(topk_idx.numel()) == T_flat * topk_n, + "topk_idx token count must equal tokens token count"); + NVTE_CHECK(static_cast(topk_weights.numel()) == T_flat * topk_n, + "topk_weights token count must equal tokens token count"); NVTE_CHECK(static_cast(recv_topk_weights.numel()) == recv_pr, "recv_topk_weights total size must equal recv_tokens recv_pr"); NVTE_CHECK(recv_tokens.scalar_type() == tokens.scalar_type(), "recv_tokens dtype (", From 33b747ac523ee3399c82c7b6f5463a2e0e6a130b Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 11 Jun 2026 23:55:09 -0700 Subject: [PATCH 17/18] EP PyTorch: move symm-mem allocation out of EpBuffer; make ep_dispatch/ep_combine accept caller-supplied output buffers with C++ symm-mem checks under zero-copy Signed-off-by: Phuong Nguyen --- tests/pytorch/distributed/run_ep.py | 42 ++- .../pytorch/csrc/extensions/ep.cpp | 30 +++ transformer_engine/pytorch/ep.py | 243 ++++-------------- 3 files changed, 120 insertions(+), 195 deletions(-) diff --git a/tests/pytorch/distributed/run_ep.py 
b/tests/pytorch/distributed/run_ep.py index 0e29ca7eae..b09071a57c 100644 --- a/tests/pytorch/distributed/run_ep.py +++ b/tests/pytorch/distributed/run_ep.py @@ -173,8 +173,11 @@ def test_dispatch_fwd_bwd(self): buf = self._make_buffer() topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) tokens_p = tokens.detach().clone().requires_grad_(True) - recv_t, _recv_w, _tc = ep_dispatch(buf, tokens_p, topk_idx, w) - loss = 0.5 * (recv_t.float() ** 2).sum() + recv_t, recv_w, _tc = ep_dispatch(buf, tokens_p, topk_idx, w) + # Pull recv_w into the loss with a zero scale so both dispatch outputs + # contribute a (possibly-zero) grad — backward respects user-supplied + # grad inputs and won't fabricate Nones into zeros. + loss = 0.5 * (recv_t.float() ** 2).sum() + 0.0 * recv_w.float().sum() loss.backward() torch.cuda.synchronize() torch.testing.assert_close( @@ -293,6 +296,41 @@ def bwd(k): rtol=5e-2, ) + # Caller-supplied output buffers (autograd) + + def test_dispatch_caller_recv_buffers_autograd(self): + """ep_dispatch with caller-supplied recv buffers; fwd+bwd matches default-alloc grads.""" + buf = self._make_buffer() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, self.cfg.ep_size) + recv_tokens, recv_w, _ = self._make_raw_recv() + tokens_p = tokens.detach().clone().requires_grad_(True) + rt, rw, _tc = ep_dispatch( + buf, tokens_p, topk_idx, w, recv_tokens=recv_tokens, recv_topk_weights=recv_w + ) + self.assertEqual(rt.data_ptr(), recv_tokens.data_ptr()) + self.assertEqual(rw.data_ptr(), recv_w.data_ptr()) + (0.5 * (rt.float() ** 2).sum() + 0.0 * rw.float().sum()).backward() + torch.cuda.synchronize() + torch.testing.assert_close( + tokens_p.grad.float(), tokens.float() * float(TOP_K), atol=5e-2, rtol=5e-2 + ) + + def test_combine_grad_expert_out_autograd(self): + """ep_combine with caller-supplied grad_expert_out; bwd writes into that slot.""" + buf = self._make_buffer() + topk_idx, tokens, w = _make_identity_inputs(self.cfg.rank, 
self.cfg.ep_size) + tokens_p = tokens.detach().clone().requires_grad_(True) + recv_t, recv_w, _ = ep_dispatch(buf, tokens_p, topk_idx, w) + eo = self._weighted(recv_t, recv_w) + grad_eo = torch.empty_like(eo) + gp = grad_eo.data_ptr() + out = ep_combine(buf, eo, grad_expert_out=grad_eo) + (0.5 * (out.float() ** 2).sum()).backward() + torch.cuda.synchronize() + self.assertEqual(grad_eo.data_ptr(), gp) + torch.testing.assert_close(out.float(), tokens.float(), atol=5e-2, rtol=5e-2) + torch.testing.assert_close(tokens_p.grad.float(), tokens.float(), atol=5e-2, rtol=5e-2) + # Input validation def test_topk_int32_raises_clear_error(self): diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index a7c15a1140..9953bc3993 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -81,6 +81,29 @@ NVTECommWindow maybe_make_window(const at::Tensor& t) { #endif } +// When zero-copy is enabled, the named tensor must be symm-mem-backed on the +// EP group. Throws a clear error otherwise. No-op when zero-copy is off or +// symm-mem support isn't compiled in. Mirrors maybe_make_window's resolution +// path but turns the "not symm-mem" outcome into a hard error. +void check_symm_mem_required(const at::Tensor& t, const char* name) { +#ifdef NCCL_HAS_SYMMEM_SUPPORT + if (!g_zero_copy_enabled.load(std::memory_order_relaxed)) return; + NVTE_CHECK(!g_ep_group_name.empty(), + "Zero-copy is enabled but EP group name is unset; call ep_initialize first."); + c10::intrusive_ptr sm; + try { + sm = c10d::symmetric_memory::rendezvous(t, g_ep_group_name); + } catch (const std::exception&) { + sm = nullptr; + } + NVTE_CHECK(sm != nullptr, "ep zero-copy: ", name, + " must be symm-mem-backed on the EP group (allocate via symm_mem_alloc)."); +#else + (void)t; + (void)name; +#endif +} + // The backend only accepts int64 topk_idx. 
The PyTorch wrapper enforces this // at the boundary so the per-step ops don't need an upcast workspace. void check_topk_idx_int64(at::Tensor topk_idx) { @@ -202,6 +225,8 @@ void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, NVTE_CHECK(recv_tokens.scalar_type() == tokens.scalar_type(), "recv_tokens dtype (", c10::toString(recv_tokens.scalar_type()), ") must match tokens dtype (", c10::toString(tokens.scalar_type()), ")"); + check_symm_mem_required(recv_tokens, "recv_tokens"); + check_symm_mem_required(recv_topk_weights, "recv_topk_weights"); auto tok_dtype = GetTransformerEngineDType(tokens.scalar_type()); auto handle_mem_te = makeTransformerEngineTensor( @@ -241,6 +266,7 @@ void ep_combine(at::Tensor handle_mem, at::Tensor expert_out, at::Tensor result) NVTE_CHECK(result.scalar_type() == expert_out.scalar_type(), "result dtype (", c10::toString(result.scalar_type()), ") must match expert_out dtype (", c10::toString(expert_out.scalar_type()), ")"); + check_symm_mem_required(expert_out, "expert_out"); auto eo_dtype = GetTransformerEngineDType(expert_out.scalar_type()); auto handle_mem_te = makeTransformerEngineTensor( @@ -275,6 +301,8 @@ void ep_dispatch_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor g_recv_t NVTE_CHECK(grad_tokens.scalar_type() == grad.scalar_type(), "grad_tokens dtype (", c10::toString(grad_tokens.scalar_type()), ") must match grad dtype (", c10::toString(grad.scalar_type()), ")"); + check_symm_mem_required(grad, "grad (dispatch_bwd input)"); + check_symm_mem_required(g_recv_topk_weights, "g_recv_topk_weights"); auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); auto handle_mem_te = makeTransformerEngineTensor( @@ -306,6 +334,8 @@ void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expe NVTE_CHECK(grad_expert_out.scalar_type() == grad.scalar_type(), "grad_expert_out dtype (", c10::toString(grad_expert_out.scalar_type()), ") must match grad dtype (", 
c10::toString(grad.scalar_type()), ")"); + check_symm_mem_required(grad, "grad (combine_bwd input)"); + check_symm_mem_required(grad_expert_out, "grad_expert_out"); auto g_dtype = GetTransformerEngineDType(grad.scalar_type()); auto handle_mem_te = makeTransformerEngineTensor( diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 89951bc67c..537978f593 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -179,40 +179,22 @@ def ep_finalize() -> None: class EpBuffer: - """Per-microbatch EP layer state: routing handle + persistent payload slots. - - Owns the per-call ``handle_mem`` routing scratch and the payload buffers - consumed by :func:`ep_dispatch` / :func:`ep_combine`. Allocate one - EpBuffer per concurrently-in-flight call on the layer (one per PP-1F1B - microbatch); sharing across overlapping calls clobbers tensors the - earlier bwd still reads. Call ``record_stream`` from streams other than - the allocation stream. - - Cross-rank payload slots are symm-mem-backed when ``ep_bootstrap`` was - called with ``zero_copy=True`` (requires ``ep_group``); otherwise plain - HBM. + """Per-microbatch EP layer state holding handle_mem and token_counts. + Cross-rank payload buffers are caller-supplied to ep_dispatch and + ep_combine; allocate via symm_mem_alloc in zero-copy mode. + Use one EpBuffer per concurrently-in-flight call (e.g. per PP-1F1B microbatch). """ __slots__ = ( - # routing "handle_mem", "top_k", "alignment", - # layer config "max_tokens_per_rank", "recv_capacity_per_rank", "hidden_dim", "num_local_experts", "payload_dtype", "device", - # Symm-mem slots (zero-copy only). 
Each is reused across fwd and bwd: - # dispatch_symm_buf: fwd out (recv_tokens) / bwd in (g_recv_tokens) - # dispatch_w_symm_buf: fwd out (recv_topk_w) / bwd in (g_recv_topk_w) - # combine_symm_buf: fwd in (expert_out) / bwd out (g_expert_out) - "dispatch_symm_buf", - "dispatch_w_symm_buf", - "combine_symm_buf", - # Per-rank scratch (always HBM). "token_counts", "zero_copy", ) @@ -225,7 +207,6 @@ def __init__( hidden_dim: int, num_local_experts: int, alignment: int = 0, - ep_group: Optional[dist.ProcessGroup] = None, payload_dtype: torch.dtype = torch.bfloat16, device: Optional[torch.device] = None, ) -> None: @@ -242,30 +223,10 @@ def __init__( self.num_local_experts = int(num_local_experts) self.payload_dtype = payload_dtype self.device = device + self.zero_copy = bool(tex.ep_get_zero_copy()) size_bytes = tex.ep_handle_mem_size(self.top_k, self.alignment) self.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) - - recv_shape = (self.recv_capacity_per_rank, self.hidden_dim) - zero_copy = bool(tex.ep_get_zero_copy()) - self.zero_copy = zero_copy - if zero_copy: - if ep_group is None: - raise ValueError("EpBuffer requires ep_group when ep_bootstrap(zero_copy=True).") - self.dispatch_symm_buf = symm_mem_alloc( - recv_shape, payload_dtype, ep_group, device=device - ) - self.dispatch_w_symm_buf = symm_mem_alloc( - (self.recv_capacity_per_rank,), torch.float32, ep_group, device=device - ) - self.combine_symm_buf = symm_mem_alloc( - recv_shape, payload_dtype, ep_group, device=device - ) - else: - self.dispatch_symm_buf = None - self.dispatch_w_symm_buf = None - self.combine_symm_buf = None - # token_counts is local-only routing scratch; always plain HBM. 
self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device) @classmethod @@ -277,37 +238,19 @@ def from_external( hidden_dim: int, num_local_experts: int, *, - dispatch_symm_buf: Optional[torch.Tensor] = None, - dispatch_w_symm_buf: Optional[torch.Tensor] = None, - combine_symm_buf: Optional[torch.Tensor] = None, token_counts: Optional[torch.Tensor] = None, alignment: int = 0, payload_dtype: torch.dtype = torch.bfloat16, device: Optional[torch.device] = None, ) -> "EpBuffer": - """Construct from caller-allocated buffers. - - In zero-copy mode dispatch_symm_buf, dispatch_w_symm_buf, and - combine_symm_buf must all be supplied and symm-mem-backed; in - non-zero-copy mode they must all be None (ops allocate per call). - handle_mem is always allocated fresh. - """ + """Construct from a caller-allocated token_counts; handle_mem is always fresh.""" if device is None: device = torch.device("cuda", torch.cuda.current_device()) alignment = int(alignment) if alignment > 1 and (alignment & (alignment - 1)) != 0: raise ValueError(f"alignment must be 0, 1, or a power of two (got {alignment}).") - recv_shape = (recv_capacity_per_rank, hidden_dim) - recv_w_shape = (recv_capacity_per_rank,) counts_shape = (num_local_experts,) - def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torch.Tensor: - if tuple(t.shape) != shape: - raise ValueError(f"{name} shape {tuple(t.shape)} != expected {shape}") - if t.dtype != dtype: - raise ValueError(f"{name} dtype {t.dtype} != expected {dtype}") - return t - inst = cls.__new__(cls) inst.top_k = int(top_k) inst.alignment = alignment @@ -322,53 +265,22 @@ def _check(t: torch.Tensor, name: str, shape: tuple, dtype: torch.dtype) -> torc size_bytes = tex.ep_handle_mem_size(inst.top_k, inst.alignment) inst.handle_mem = torch.empty(int(size_bytes), dtype=torch.uint8, device=device) - if inst.zero_copy: - if dispatch_symm_buf is None or dispatch_w_symm_buf is None or combine_symm_buf is None: + if 
token_counts is not None: + if tuple(token_counts.shape) != counts_shape: raise ValueError( - "EpBuffer.from_external: zero-copy mode requires dispatch_symm_buf, " - "dispatch_w_symm_buf, and combine_symm_buf (all symm-mem-backed)." + f"token_counts shape {tuple(token_counts.shape)} != expected {counts_shape}" ) - inst.dispatch_symm_buf = _check( - dispatch_symm_buf, "dispatch_symm_buf", recv_shape, payload_dtype - ) - inst.dispatch_w_symm_buf = _check( - dispatch_w_symm_buf, "dispatch_w_symm_buf", recv_w_shape, torch.float32 - ) - inst.combine_symm_buf = _check( - combine_symm_buf, "combine_symm_buf", recv_shape, payload_dtype - ) + if token_counts.dtype != torch.int32: + raise ValueError(f"token_counts dtype {token_counts.dtype} != expected int32") + inst.token_counts = token_counts else: - if ( - dispatch_symm_buf is not None - or dispatch_w_symm_buf is not None - or combine_symm_buf is not None - ): - raise ValueError( - "EpBuffer.from_external: dispatch_symm_buf / dispatch_w_symm_buf / " - "combine_symm_buf are only used in zero-copy mode." 
- ) - inst.dispatch_symm_buf = None - inst.dispatch_w_symm_buf = None - inst.combine_symm_buf = None - inst.token_counts = ( - _check(token_counts, "token_counts", counts_shape, torch.int32) - if token_counts is not None - else torch.empty(counts_shape, dtype=torch.int32, device=device) - ) + inst.token_counts = torch.empty(counts_shape, dtype=torch.int32, device=device) return inst def record_stream(self, stream: torch.cuda.Stream) -> None: - """Record stream as a user of all owned tensors so the caching allocator - defers reclaim until stream has caught up.""" - for t in ( - self.handle_mem, - self.dispatch_symm_buf, - self.dispatch_w_symm_buf, - self.combine_symm_buf, - self.token_counts, - ): - if t is not None: - t.record_stream(stream) + """Defer caching-allocator reclaim of owned tensors until stream catches up.""" + self.handle_mem.record_stream(stream) + self.token_counts.record_stream(stream) # torch.library custom ops (so they don't graph-break under torch.compile) @@ -510,11 +422,7 @@ def _ep_combine_raw(buffer: "EpBuffer", expert_out: torch.Tensor, result: torch. class _EpDispatch(torch.autograd.Function): - """Autograd-aware prepare + dispatch. Fwd produces recv_tokens (alias of - dispatch_symm_buf in zero-copy, fresh otherwise). Zero-copy bwd requires - the incoming grads to alias dispatch_symm_buf / dispatch_w_symm_buf - (no implicit staging). Fwd/bwd share handle_mem; do not re-run ep_prepare. 
- """ + """Autograd prepare+dispatch; bwd uses user-supplied grad inputs as-is.""" @staticmethod def forward( # type: ignore[override] @@ -522,7 +430,6 @@ def forward( # type: ignore[override] handle_mem: torch.Tensor, top_k: int, alignment: int, - zero_copy: bool, recv_tokens: torch.Tensor, recv_topk_weights: torch.Tensor, token_counts: torch.Tensor, @@ -530,7 +437,7 @@ def forward( # type: ignore[override] tokens: torch.Tensor, topk_weights: torch.Tensor, ): - """Prepare + dispatch; saves shapes for the bwd pass.""" + """Prepare + dispatch fwd.""" torch.ops.transformer_engine_ep.prepare( handle_mem, top_k, topk_idx, token_counts, alignment ) @@ -543,12 +450,6 @@ def forward( # type: ignore[override] recv_topk_weights, ) ctx.handle_mem = handle_mem - ctx.zero_copy = zero_copy - # Stash the symm-mem slot pointers so bwd can enforce alias of the - # grad inputs. In non-zero-copy mode the slots are fresh per call; - # no enforcement is meaningful, so leave the pointers as None. - ctx.dispatch_symm_ptr = recv_tokens.data_ptr() if zero_copy else None - ctx.dispatch_w_symm_ptr = recv_topk_weights.data_ptr() if zero_copy else None ctx.tokens_shape = tokens.shape ctx.tokens_dtype = tokens.dtype ctx.topk_weights_shape = topk_weights.shape @@ -564,30 +465,8 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] - """Dispatch bwd; in zero-copy the grad inputs must alias the symm-mem slots.""" + """Dispatch bwd; uses user-supplied grad inputs as-is.""" device = ctx.handle_mem.device - if g_recv_tokens is None: - g_recv_tokens = torch.zeros( - ctx.recv_capacity, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device - ) - if g_recv_topk_weights is None: - g_recv_topk_weights = torch.zeros(ctx.recv_capacity, dtype=torch.float32, device=device) - if not g_recv_tokens.is_contiguous(): - g_recv_tokens = g_recv_tokens.contiguous() - if not g_recv_topk_weights.is_contiguous(): - 
g_recv_topk_weights = g_recv_topk_weights.contiguous() - if ctx.zero_copy: - if g_recv_tokens.data_ptr() != ctx.dispatch_symm_ptr: - raise RuntimeError( - "ep_dispatch bwd: zero-copy mode requires g_recv_tokens to alias " - "buffer.dispatch_symm_buf (write MLP_bwd's grad into that slot; " - "no implicit copy)." - ) - if g_recv_topk_weights.data_ptr() != ctx.dispatch_w_symm_ptr: - raise RuntimeError( - "ep_dispatch bwd: zero-copy mode requires g_recv_topk_weights to alias " - "buffer.dispatch_w_symm_buf (no implicit copy)." - ) grad_tokens = torch.empty( ctx.tokens_T_flat, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device ) @@ -605,7 +484,6 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: None, # handle_mem None, # top_k None, # alignment - None, # zero_copy None, # recv_tokens None, # recv_topk_weights None, # token_counts @@ -616,66 +494,36 @@ def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: class _EpCombine(torch.autograd.Function): - """Autograd-aware combine. Zero-copy mode requires expert_out to alias - combine_symm_buf (no implicit staging), and that storage is reused as the - bwd grad slot. Non-zero-copy mode reads expert_out directly and allocates - the bwd grad slot fresh. Caller pre-applies topk weighting. - """ + """Autograd combine; bwd writes into grad_expert_out, or expert_out's storage if None.""" @staticmethod def forward( # type: ignore[override] ctx, handle_mem: torch.Tensor, - combine_symm_buf: Optional[torch.Tensor], num_local_tokens: int, hidden_dim: int, - zero_copy: bool, + grad_expert_out: Optional[torch.Tensor], expert_out: torch.Tensor, ): - """Combine fwd; zero-copy requires expert_out to alias combine_symm_buf.""" - if zero_copy: - if combine_symm_buf is None: - raise RuntimeError( - "ep_combine: zero-copy mode requires buffer.combine_symm_buf to be allocated." 
- ) - if combine_symm_buf.data_ptr() != expert_out.data_ptr(): - raise RuntimeError( - "ep_combine: zero-copy mode requires expert_out to alias " - "buffer.combine_symm_buf (write expert outputs directly into that slot; " - "no implicit copy)." - ) + """Combine fwd; stashes grad_expert_out (or expert_out) as the bwd output slot.""" device = expert_out.device result = torch.empty(num_local_tokens, hidden_dim, dtype=expert_out.dtype, device=device) torch.ops.transformer_engine_ep.combine(handle_mem, expert_out, result) ctx.handle_mem = handle_mem - ctx.combine_symm_buf = combine_symm_buf # reused as grad slot in zero-copy - ctx.zero_copy = zero_copy - ctx.recv_capacity = expert_out.shape[0] - ctx.hidden_dim = expert_out.shape[-1] - ctx.expert_out_dtype = expert_out.dtype + ctx.grad_expert_out = grad_expert_out if grad_expert_out is not None else expert_out return result @staticmethod def backward(ctx, g_result): # type: ignore[override] - """Combine bwd; writes into combine_symm_buf in zero-copy or a fresh slot otherwise.""" if not g_result.is_contiguous(): g_result = g_result.contiguous() - if ctx.zero_copy: - grad_combine_in = ctx.combine_symm_buf - else: - grad_combine_in = torch.empty( - ctx.recv_capacity, - ctx.hidden_dim, - dtype=ctx.expert_out_dtype, - device=ctx.handle_mem.device, - ) + grad_combine_in = ctx.grad_expert_out torch.ops.transformer_engine_ep.combine_bwd(ctx.handle_mem, g_result, grad_combine_in) return ( None, # handle_mem - None, # combine_symm_buf None, # num_local_tokens None, # hidden_dim - None, # zero_copy + None, # grad_expert_out grad_combine_in, ) @@ -696,12 +544,15 @@ def ep_dispatch( tokens: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, + *, + recv_tokens: Optional[torch.Tensor] = None, + recv_topk_weights: Optional[torch.Tensor] = None, ): - """Run prepare + dispatch with autograd. topk_idx must be int64. + """Prepare + dispatch with autograd. topk_idx must be int64. 
- Returns (recv_tokens, recv_topk_weights, token_counts). In zero-copy mode - recv_tokens / recv_topk_weights alias the buffer's persistent symm-mem - slots; otherwise they are freshly allocated. token_counts is non-diff. + recv_tokens / recv_topk_weights are used as-is if supplied, else allocated. + Zero-copy mode requires both to be supplied and symm-mem-backed. + Returns (recv_tokens, recv_topk_weights, token_counts); token_counts is non-diff. """ _require_bf16("tokens", tokens) if topk_weights.dtype is not torch.float32: @@ -709,16 +560,19 @@ def ep_dispatch( f"topk_weights must be float32; got dtype={topk_weights.dtype}. " "Cast with topk_weights.float() before calling." ) - if buffer.zero_copy: - recv_tokens = buffer.dispatch_symm_buf - recv_topk_weights = buffer.dispatch_w_symm_buf - else: + if buffer.zero_copy and (recv_tokens is None or recv_topk_weights is None): + raise ValueError( + "ep_dispatch: zero-copy mode requires caller-supplied recv_tokens and " + "recv_topk_weights (allocate via symm_mem_alloc)." + ) + if recv_tokens is None: recv_tokens = torch.empty( buffer.recv_capacity_per_rank, buffer.hidden_dim, dtype=buffer.payload_dtype, device=buffer.device, ) + if recv_topk_weights is None: recv_topk_weights = torch.empty( buffer.recv_capacity_per_rank, dtype=torch.float32, device=buffer.device ) @@ -726,7 +580,6 @@ def ep_dispatch( buffer.handle_mem, buffer.top_k, buffer.alignment, - buffer.zero_copy, recv_tokens, recv_topk_weights, buffer.token_counts, @@ -741,22 +594,26 @@ def ep_combine( expert_out: torch.Tensor, *, num_local_tokens: Optional[int] = None, + grad_expert_out: Optional[torch.Tensor] = None, ): - """Combine expert outputs back to the source rank, with autograd. Caller - pre-applies topk weighting. Zero-copy mode requires expert_out to alias - buffer.combine_symm_buf (write expert outputs into that slot directly). + """Combine with autograd; caller pre-applies topk weighting. 
- Result shape is (num_local_tokens, buffer.hidden_dim); defaults to - buffer.max_tokens_per_rank rows. + grad_expert_out is the slot the bwd writes into; if None, expert_out's storage is reused. + Zero-copy mode requires both expert_out and grad_expert_out to be symm-mem-backed. + Result shape is (num_local_tokens, buffer.hidden_dim); defaults to buffer.max_tokens_per_rank rows. """ _require_bf16("expert_out", expert_out) + if buffer.zero_copy and grad_expert_out is None: + raise ValueError( + "ep_combine: zero-copy mode requires caller-supplied grad_expert_out " + "(allocate via symm_mem_alloc)." + ) if num_local_tokens is None: num_local_tokens = buffer.max_tokens_per_rank return _EpCombine.apply( buffer.handle_mem, - buffer.combine_symm_buf, num_local_tokens, buffer.hidden_dim, - buffer.zero_copy, + grad_expert_out, expert_out, ) From 3e9e1cf27f12a9ea92414451d925e106d0b136f5 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 12 Jun 2026 00:00:04 -0700 Subject: [PATCH 18/18] EP PyTorch: enforce contiguous caller-supplied EP buffers in C++ and normalize bwd grad layout in Python Signed-off-by: Phuong Nguyen --- transformer_engine/pytorch/csrc/extensions/ep.cpp | 6 ++++++ transformer_engine/pytorch/ep.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index 9953bc3993..67e18cd70a 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -208,6 +208,8 @@ void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, check_topk_idx_int64(topk_idx); NVTE_CHECK(tokens.is_contiguous(), "tokens must be contiguous"); NVTE_CHECK(topk_weights.is_contiguous(), "topk_weights must be contiguous"); + NVTE_CHECK(recv_tokens.is_contiguous(), "recv_tokens must be contiguous"); + NVTE_CHECK(recv_topk_weights.is_contiguous(), "recv_topk_weights must be contiguous"); const 
size_t H = static_cast(tokens.size(-1)); const size_t T_flat = tokens.numel() / H; @@ -286,6 +288,8 @@ void ep_dispatch_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor g_recv_t NVTE_CHECK(grad.dim() >= 2, "grad must be at least 2D [..., recv_pr, H]"); NVTE_CHECK(grad_tokens.dim() >= 2, "grad_tokens must be at least 2D [..., H]"); NVTE_CHECK(grad_topk_weights.dim() >= 2, "grad_topk_weights must be at least 2D [..., top_k]"); + NVTE_CHECK(grad.is_contiguous(), "grad must be contiguous"); + NVTE_CHECK(g_recv_topk_weights.is_contiguous(), "g_recv_topk_weights must be contiguous"); const size_t H = static_cast(grad.size(-1)); const size_t recv_pr = grad.numel() / H; @@ -325,6 +329,8 @@ void ep_combine_bwd(at::Tensor handle_mem, at::Tensor grad, at::Tensor grad_expe auto stream = at::cuda::getCurrentCUDAStream().stream(); NVTE_CHECK(grad.dim() >= 2, "grad must be at least 2D [..., H]"); NVTE_CHECK(grad_expert_out.dim() >= 2, "grad_expert_out must be at least 2D [..., recv_pr, H]"); + NVTE_CHECK(grad.is_contiguous(), "grad must be contiguous"); + NVTE_CHECK(grad_expert_out.is_contiguous(), "grad_expert_out must be contiguous"); const size_t H = static_cast(grad.size(-1)); const size_t T_flat = grad.numel() / H; diff --git a/transformer_engine/pytorch/ep.py b/transformer_engine/pytorch/ep.py index 537978f593..0caf3b8c8b 100644 --- a/transformer_engine/pytorch/ep.py +++ b/transformer_engine/pytorch/ep.py @@ -465,8 +465,10 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, g_recv_tokens, g_recv_topk_weights, _g_token_counts): # type: ignore[override] - """Dispatch bwd; uses user-supplied grad inputs as-is.""" + """Dispatch bwd; normalizes grad-input layout, otherwise passes through.""" device = ctx.handle_mem.device + g_recv_tokens = g_recv_tokens.contiguous() + g_recv_topk_weights = g_recv_topk_weights.contiguous() grad_tokens = torch.empty( ctx.tokens_T_flat, ctx.hidden_dim, dtype=ctx.tokens_dtype, device=device )