Skip to content

[PyTorch] Expert Parallelism: PyTorch wrapper + autograd ops with symm-mem zero-copy#3035

Open
phu0ngng wants to merge 18 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-pytorch-on-commwindow
Open

[PyTorch] Expert Parallelism: PyTorch wrapper + autograd ops with symm-mem zero-copy#3035
phu0ngng wants to merge 18 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-pytorch-on-commwindow

Conversation

@phu0ngng

@phu0ngng phu0ngng commented May 22, 2026

Copy link
Copy Markdown
Collaborator

Summary

Second PR in the TE Expert Parallelism (EP) series. Adds the PyTorch binding on top of the common C API (#3034): exposes EP dispatch/combine as torch.library custom ops with autograd, and plumbs NCCL symmetric-memory windows through for the zero-copy path.

Payload tensors allocated via te.pytorch.ep.symm_mem_alloc take the one-sided zero-copy path when ep_bootstrap(zero_copy=True); anything else falls back to staged-copy, so the API stays drop-in compatible with any allocator.

Implementation

Public Python API (transformer_engine/pytorch/ep.py)

    EpBuffer, ep_bootstrap, ep_finalize,                                                                                                                                                                                                                                                        ep_dispatch, ep_combine,
    symm_mem_alloc,                                                                                                                                                                                                                                                                         )
  • ep_bootstrap / ep_finalize - one-time per-process init/teardown. Borrows the NCCL comm from ep_group via ProcessGroupNCCL._comm_ptr() (no separate ncclUniqueId bootstrap). ep_finalize is optional - an atexit handler covers normal shutdown; call it explicitly before dist.destroy_process_group(). Requires ep_group.size() >= 2.
  • symm_mem_alloc(shape, dtype, ep_group) - per-rank tensor backed by NCCL symmetric memory, rendezvoused on ep_group.
  • EpBuffer - per-layer state: routing handle + persistent payload slots (recv_tokens, combine_in, grad buffers). One per concurrently-in-flight call (e.g. PP-1F1B microbatch). Symm-mem-backed when zero_copy=True.
  • ep_dispatch / ep_combine - autograd-aware per-step ops, registered as torch.library.custom_op with correct mutates_args, so they compose with torch.compile fullgraph and CUDA graphs.
    Current payload dtype is restricted to bfloat16; FP8 quantize/dequantize stays outside the EP boundary.

C++ bindings (transformer_engine/pytorch/csrc/extensions/ep.cpp)

  • POD-only pybind boundary (primitives + pybind11::object for dtype) - no c10d ABI on the boundary. - maybe_make_window() resolves each payload tensor to an NVTECommWindow via c10d::symmetric_memory::rendezvous; non-symm-mem tensors return kNoWindow and the backend picks staged-copy automatically.
  • Zero-copy toggle captured at ep_initialize and forwarded into NVTEEpGroupConfig.zero_copy.

Build

build_tools/pytorch.py propagates -DNVTE_WITH_NCCL_EP (gated on NVTE_BUILD_WITH_NCCL_EP=1, default on) and -DUSE_NCCL so PyTorch's symm-mem feature macros are visible. When NCCL EP is off, ep.cpp no-ops behind the #ifdef.

Testing

  • tests/pytorch/distributed/run_ep.py - 8-test suite: prepare correctness, raw dispatch/combine identity round-trip, dispatch fwd+bwd VJP, full fwd+bwd round-trip, multi-iter bit-stability, CUDA graph capture, PP-1F1B 3-buffer interleave, int64 topk_idx validation. Launcher run_test_ep.sh auto-detects GPUs (skips with <4). Pytest driver: tests/pytorch/distributed/test_ep.py.
  • Example: examples/pytorch/ep/ep_moe.py - minimal end-to-end MoE fwd+bwd driver with --check against an analytical reference.
  • Bench: examples/pytorch/ep/bench/ep_bench.py - times raw + autograd dispatch/combine, optional --cuda-graph capture and --kineto/--nsys profiling.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@phu0ngng phu0ngng requested review from ksivaman and ptrendx as code owners May 22, 2026 02:54
@greptile-apps

greptile-apps Bot commented May 22, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

Adds the PyTorch binding layer for Expert Parallelism: torch.library custom ops with full autograd support, NCCL symmetric-memory zero-copy plumbing, a per-microbatch EpBuffer state object, and an 8-test distributed suite covering round-trips, VJP correctness, CUDA graph capture, and PP-1F1B interleave.

  • transformer_engine/pytorch/ep.py — public API (ep_bootstrap, ep_finalize, EpBuffer, ep_dispatch, ep_combine, symm_mem_alloc) backed by _EpDispatch / _EpCombine autograd Functions; dtype and contiguity guards added at the Python boundary.
  • transformer_engine/pytorch/csrc/extensions/ep.cpp — POD-only pybind layer; maybe_make_window resolves symm-mem windows for zero-copy, check_symm_mem_required enforces symm-mem backing for zero-copy payload slots, and per-step ops include cross-dimension shape checks and contiguity guards.
  • Tests / examples / buildrun_ep.py covers forward, backward, multi-iter stability, CUDA graph, and interleave; build system gates the extension on NVTE_BUILD_WITH_NCCL_EP=1.

Confidence Score: 4/5

Safe to merge for the default zero_copy=False path, which is what all tests exercise; the experimental zero_copy=True path contains a backward crash that must be resolved before any training user enables it.

The zero_copy=False code path (the production path) is well-tested and the validation guards added in this PR are thorough. However, when zero_copy=True is set, both ep_dispatch_bwd and ep_combine_bwd call check_symm_mem_required on gradient tensors that PyTorch's autograd engine allocates — these are never symm-mem-backed — causing an unconditional throw on every backward pass. No test exercises this path, so the breakage would surprise any user who opts into the documented experimental feature and then runs training.

transformer_engine/pytorch/csrc/extensions/ep.cpp — specifically ep_dispatch_bwd (lines 308-309) and ep_combine_bwd (lines 343-344) where check_symm_mem_required is applied to autograd-allocated gradient inputs.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ep.py Main Python EP API — bootstrap, EpBuffer, autograd Function wrappers, and high-level dispatch/combine ops. float32 topk_weights validation and token-count cross-checks were addressed. All zero_copy=False paths look correct; the experimental zero_copy=True path is broken in backward (see ep.cpp finding).
transformer_engine/pytorch/csrc/extensions/ep.cpp C++ pybind layer for EP. Forward ops are well-validated (contiguity, dtype, shape cross-checks added). Two check_symm_mem_required calls in ep_dispatch_bwd and ep_combine_bwd unconditionally fail for autograd-allocated gradient tensors when zero_copy=True, making zero-copy autograd unusable.
transformer_engine/pytorch/csrc/extensions.h Added EP function declarations — ep_initialize, ep_finalize, ep_get_zero_copy, ep_handle_mem_size, and the per-step op signatures. Clean.
build_tools/pytorch.py Propagates -DNVTE_WITH_NCCL_EP and -DUSE_NCCL when NVTE_BUILD_WITH_NCCL_EP=1. Correct opt-in gating, no-ops when disabled.
tests/pytorch/distributed/run_ep.py Multi-process EP test suite covering prepare, dispatch/combine round-trips, autograd VJPs, multi-iter stability, CUDA graph capture, PP-1F1B interleave, and int64 validation. All tests use ZERO_COPY=False; zero-copy backward path is not tested.
tests/pytorch/distributed/test_ep.py Pytest driver that spawns run_test_ep.sh; skips when <4 GPUs. Clean wrapper.
tests/pytorch/distributed/run_test_ep.sh torchrun launcher with setsid/timeout hang detection and stdout log scanning. Clean.
examples/pytorch/ep/ep_moe.py Minimal end-to-end MoE fwd+bwd driver with analytical reference check. Clean example code.
examples/pytorch/ep/bench/ep_bench.py Performance benchmark for raw and autograd dispatch/combine with optional CUDA graph and Kineto profiling. Clean.
qa/L1_pytorch_distributed_unittest/test.sh Adds test_ep.py to the L1 distributed unittest suite. One-line addition, correct.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant User as User Code
    participant PyEP as ep.py (Python)
    participant AutoF as _EpDispatch/_EpCombine
    participant CustomOp as torch.library custom ops
    participant CppEP as ep.cpp (C++)
    participant NVTE as nvte_ep_* (C API)

    User->>PyEP: "ep_bootstrap(ep_group, zero_copy=False)"
    PyEP->>PyEP: barrier + _comm_ptr()
    PyEP->>CppEP: tex.ep_initialize(comm_ptr, group_name, ...)
    CppEP->>NVTE: nvte_ep_initialize(comm, cfg)

    User->>PyEP: ep_dispatch(buffer, tokens, topk_idx, topk_weights)
    PyEP->>AutoF: _EpDispatch.apply(handle_mem, ...)
    AutoF->>CustomOp: transformer_engine_ep::prepare(handle_mem, topk_idx, ...)
    CustomOp->>CppEP: tex.ep_prepare(...)
    CppEP->>NVTE: nvte_ep_prepare(handle_mem, topk_idx, token_counts, layer_cfg)
    AutoF->>CustomOp: transformer_engine_ep::dispatch(handle_mem, tokens, recv_tokens, ...)
    CustomOp->>CppEP: tex.ep_dispatch(...)
    CppEP->>CppEP: maybe_make_window(recv_tokens) → NVTECommWindow
    CppEP->>NVTE: nvte_ep_dispatch(handle_mem, tokens, tokens_win, recv_tokens, recv_win, ...)
    AutoF-->>User: recv_tokens, recv_topk_weights, token_counts

    User->>User: expert computation on recv_tokens
    User->>PyEP: ep_combine(buffer, expert_out)
    PyEP->>AutoF: _EpCombine.apply(handle_mem, expert_out, ...)
    AutoF->>CustomOp: transformer_engine_ep::combine(handle_mem, expert_out, result)
    CustomOp->>CppEP: tex.ep_combine(...)
    CppEP->>NVTE: nvte_ep_combine(handle_mem, expert_out, expert_out_win, result)
    AutoF-->>User: result (combined tokens)

    User->>User: loss.backward()
    AutoF->>CustomOp: transformer_engine_ep::combine_bwd(handle_mem, g_result, grad_expert_out)
    CustomOp->>CppEP: tex.ep_combine_bwd(...)
    CppEP->>NVTE: nvte_ep_combine_bwd(...)
    AutoF->>CustomOp: transformer_engine_ep::dispatch_bwd(handle_mem, g_recv_tokens, ...)
    CustomOp->>CppEP: tex.ep_dispatch_bwd(...)
    CppEP->>NVTE: nvte_ep_dispatch_bwd(...)
    AutoF-->>User: grad_tokens, grad_topk_weights
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant User as User Code
    participant PyEP as ep.py (Python)
    participant AutoF as _EpDispatch/_EpCombine
    participant CustomOp as torch.library custom ops
    participant CppEP as ep.cpp (C++)
    participant NVTE as nvte_ep_* (C API)

    User->>PyEP: "ep_bootstrap(ep_group, zero_copy=False)"
    PyEP->>PyEP: barrier + _comm_ptr()
    PyEP->>CppEP: tex.ep_initialize(comm_ptr, group_name, ...)
    CppEP->>NVTE: nvte_ep_initialize(comm, cfg)

    User->>PyEP: ep_dispatch(buffer, tokens, topk_idx, topk_weights)
    PyEP->>AutoF: _EpDispatch.apply(handle_mem, ...)
    AutoF->>CustomOp: transformer_engine_ep::prepare(handle_mem, topk_idx, ...)
    CustomOp->>CppEP: tex.ep_prepare(...)
    CppEP->>NVTE: nvte_ep_prepare(handle_mem, topk_idx, token_counts, layer_cfg)
    AutoF->>CustomOp: transformer_engine_ep::dispatch(handle_mem, tokens, recv_tokens, ...)
    CustomOp->>CppEP: tex.ep_dispatch(...)
    CppEP->>CppEP: maybe_make_window(recv_tokens) → NVTECommWindow
    CppEP->>NVTE: nvte_ep_dispatch(handle_mem, tokens, tokens_win, recv_tokens, recv_win, ...)
    AutoF-->>User: recv_tokens, recv_topk_weights, token_counts

    User->>User: expert computation on recv_tokens
    User->>PyEP: ep_combine(buffer, expert_out)
    PyEP->>AutoF: _EpCombine.apply(handle_mem, expert_out, ...)
    AutoF->>CustomOp: transformer_engine_ep::combine(handle_mem, expert_out, result)
    CustomOp->>CppEP: tex.ep_combine(...)
    CppEP->>NVTE: nvte_ep_combine(handle_mem, expert_out, expert_out_win, result)
    AutoF-->>User: result (combined tokens)

    User->>User: loss.backward()
    AutoF->>CustomOp: transformer_engine_ep::combine_bwd(handle_mem, g_result, grad_expert_out)
    CustomOp->>CppEP: tex.ep_combine_bwd(...)
    CppEP->>NVTE: nvte_ep_combine_bwd(...)
    AutoF->>CustomOp: transformer_engine_ep::dispatch_bwd(handle_mem, g_recv_tokens, ...)
    CustomOp->>CppEP: tex.ep_dispatch_bwd(...)
    CppEP->>NVTE: nvte_ep_dispatch_bwd(...)
    AutoF-->>User: grad_tokens, grad_topk_weights
Loading

Reviews (11): Last reviewed commit: "EP PyTorch: enforce contiguous caller-su..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/ep.py Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/ep.cpp Outdated
Comment thread transformer_engine/pytorch/ep.py Outdated
Comment on lines +558 to +568
@contextlib.contextmanager
def _zero_copy_scope(enabled: bool):
"""Toggles whether per-step ops apply the symm-mem NCCL window annotation."""
if enabled:
yield
return
tex.ep_set_zero_copy(False)
try:
yield
finally:
tex.ep_set_zero_copy(True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _zero_copy_scope does not save/restore the previous flag value

When enabled=False, the manager unconditionally sets g_zero_copy_enabled=False on entry and g_zero_copy_enabled=True on exit. If two callers both use zero_copy=False concurrently (e.g., pipeline-parallel microbatches dispatched from separate Python threads) or if the context is nested, the inner scope's finally block prematurely re-enables zero-copy while the outer scope is still active. The outer scope's finally then sets True again, but between the inner finally and the outer finally the C++ layer sees True unexpectedly.

The fix is to capture the previous value before writing and restore it unconditionally: save old = tex.ep_get_zero_copy() (adding a corresponding getter), then tex.ep_set_zero_copy(old) in the finally block. At minimum, document the single-caller-at-a-time assumption prominently so pipeline-parallel users know to serialize.

Comment thread transformer_engine/common/ep/ep_backend.cpp Outdated
@phu0ngng phu0ngng marked this pull request as draft May 22, 2026 03:03
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch 4 times, most recently from 540ef54 to bacae5f Compare May 24, 2026 00:06
Comment thread transformer_engine/pytorch/ep.py Outdated
device = expert_out.device
# Weight in payload dtype: single fused broadcast multiply into combine_in.
w = recv_topk_weights.unsqueeze(-1).to(expert_out.dtype)
torch.mul(expert_out, w, out=combine_in)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we need this?🤔
At the training scenario, the weight gets multiplied onto the activation between fc1 and fc2 (we also dispatch the weight at the same time as dispatching the tokens), or am I misunderstanding something here?

My understanding is that this multiplication is unnecessary. Furthermore, if it is removed, another problem becomes more prominent: how do we add symm buffer support for the combine input? This would require changes on the grouped GEMM side.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second this. I saw unexpected kernel here and found this same problem. A potential solution is to provide a separate path when the weight is not provided. This means the weight multiplication is handled elsewhere, and in this case skip the multiplication here.

@phu0ngng phu0ngng May 26, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to learn that we can fuse the weight x to the activation. I will make this optional.

We will need to change the GG to return the symmetric memory buf.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. we need change the grouped gemm I think

ep_group: dist.ProcessGroup,
num_experts: int,
max_tokens_per_rank: int,
recv_capacity_per_rank: int,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When allocating the buffer, we need to allocate according to the worst case. There are two scenarios here:

  • The first is rank-major, where the memory footprint is max_tokens_per_rank × num_of_ranks. This generally stays below 10 GB, which is the primary memory overhead of typical EP setups and is acceptable.
  • The second is expert-major, where the memory footprint is max_tokens_per_rank × num_of_ranks × min(topk, num_of_experts). This could reach 40–50 GB, which is unacceptable.

If I understand this correctly, we must find a way to optimize the memory usage in the expert-major layout — or alternatively, we need to fall back to the rank-major layout + explicit permutation approach.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the rank-major, you still need to overallocate the output buffer of local permute as in expert-major. Right?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two types of buffers:

The first is the EP buffer, which serves as the destination for communication (NCCL EP is a push-based design), so it requires a relatively costly registration process. These are reused globally as static buffers as much as possible, so they are allocated based on the worst-case size. In HEP, the rank-major output buffer is an EP buffer, so we only need a rank-major worst-case-size buffer. I haven't studied NCCL EP in detail, but my understanding is that if our output is a symmetric buffer, we don't need a built-in static comm buffer inside NCCL EP — meaning recv_capacity_per_rank is not needed when the output buffer is a symm buffer. I think this is worth discussing and clarifying.

The second type is regular GPU memory, which can be managed by the caching allocator. In HEP, the output of the permute operation falls into this category — it can be dynamically allocated each iteration based on the scan result, with just one additional sync required. Additionally, in sync-free mode, the size of this buffer is specified by the user.

To summarize, we may need to confirm whether recv_capacity_per_rank requires building an expert-major worst-case-size buffer inside NCCL EP. If the output is a symm buffer, we theoretically don't need such a buffer. However, if it is necessary, then we cannot accept an expert-major worst-case-size buffer. I also observed in my draft PR that NCCL EP uses more memory.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,
It's correct that if the output buffer is a symmem, then we should not need to register the gigantic IPC/MC buffer in ep_group with the size based on recv_capacity_per_rank. Let's request NCCL EP to add an option to skip this buffer allocation.

However, I think we should still ask users to specify this recv_capacity_per_rank so that we can handle overflow policy in the metadata_preprocessing rather than delaying it to dispatch phase.

@Autumn1998 Autumn1998 May 28, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need an option to skip this internal buffer.
Also, are you thinking of using recv_capacity_per_rank to support the sync-free mechanism? That is, tokens exceeding the threshold get dropped, and then trigger the flipping of the overflow flag? I think this is not correct — we should not set it at buffer initialization, but instead pass it as a parameter before the preprocess step of each dispatch, because the threshold changes every iteration.
cc @nanz-nv plz correct me if I made mistakes

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because the threshold changes every iteration.

I'm curious to learn about this possibility. From my understanding, the output buffers need to have a static size for CUDA Graph replay, and so does the recv_capacity.

@Autumn1998 Autumn1998 May 29, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for each global batch, we recalculate a new output size, since each batch has its own CUDA graph — but I'm not 100% sure on this. You may want to confirm with @nanz-nv.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is something in between. With the current way of doing full-iteration cuda graph, ideally recv_capacity_per_rank should stay the same across training, but it can sometimes gets updated. So I'd treat it as something that may change but not frequently.

@timmoon10 timmoon10 self-requested a review June 1, 2026 17:34
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch 4 times, most recently from 40d8011 to 2153492 Compare June 10, 2026 01:27
@phu0ngng phu0ngng marked this pull request as ready for review June 10, 2026 01:28
Comment thread transformer_engine/pytorch/csrc/extensions/ep.cpp
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch from 9ec1aff to 7ce8d8b Compare June 10, 2026 03:20
Comment thread transformer_engine/pytorch/ep.py
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch from b2ab069 to c8c54fd Compare June 11, 2026 00:22
Comment thread transformer_engine/pytorch/ep.py
Comment on lines +186 to +210

const size_t H = static_cast<size_t>(tokens.size(-1));
const size_t T_flat = tokens.numel() / H;
const size_t topk_n = static_cast<size_t>(topk_idx.size(-1));
const size_t recv_pr = recv_tokens.numel() / H;

NVTE_CHECK(static_cast<size_t>(topk_weights.size(-1)) == topk_n,
"topk_weights last dim must equal topk_idx last dim");
NVTE_CHECK(static_cast<size_t>(recv_topk_weights.numel()) == recv_pr,
"recv_topk_weights total size must equal recv_tokens recv_pr");
NVTE_CHECK(recv_tokens.scalar_type() == tokens.scalar_type(), "recv_tokens dtype (",
c10::toString(recv_tokens.scalar_type()), ") must match tokens dtype (",
c10::toString(tokens.scalar_type()), ")");

auto tok_dtype = GetTransformerEngineDType(tokens.scalar_type());
auto handle_mem_te = makeTransformerEngineTensor(
handle_mem.data_ptr(), Shape{static_cast<size_t>(handle_mem.numel())}, DType::kByte);
auto topk_idx_te =
makeTransformerEngineTensor(topk_idx.data_ptr(), Shape{T_flat, topk_n}, DType::kInt64);
auto tokens_te = makeTransformerEngineTensor(tokens.data_ptr(), Shape{T_flat, H}, tok_dtype);
auto topk_w_te =
makeTransformerEngineTensor(topk_weights.data_ptr(), Shape{T_flat, topk_n}, DType::kFloat32);
auto recv_tokens_te =
makeTransformerEngineTensor(recv_tokens.data_ptr(), Shape{recv_pr, H}, tok_dtype);
auto recv_topk_w_te =

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Token-count mismatch between tokens, topk_idx, and topk_weights goes unchecked

In ep_dispatch, T_flat is derived from tokens.numel() / H, but the TE tensor descriptors for topk_idx and topk_weights are assembled using that same T_flat without verifying those tensors actually contain T_flat rows. If a caller inadvertently passes a topk_idx or topk_weights from a differently-sized batch (e.g. a leftover buffer from a previous micro-batch with a different sequence length), makeTransformerEngineTensor silently builds a descriptor that claims more rows than the tensor holds, and the subsequent NCCL EP kernel performs an OOB GPU memory read. The mismatch can also arise across the two-step call sequence: ep_prepare computes its own T_flat from topk_idx.numel(), so if topk_idx is later swapped for a different-sized one before calling ep_dispatch, the routing table and the dispatch tensor descriptor silently disagree.

Adding NVTE_CHECK(topk_idx.numel() == T_flat * topk_n, ...) and NVTE_CHECK(topk_weights.numel() == T_flat * topk_n, ...) before the descriptor construction would surface this class of error immediately.

@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch from df732a5 to 67917a3 Compare June 11, 2026 16:16
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

@phu0ngng

Copy link
Copy Markdown
Collaborator Author

Pipeline #54455868 TE EP tests passed in L1_pytorch_distributed_unittest--B200_8GPU and L1_pytorch_distributed_unittest--H100_4GPU. There are other failures that are unrelated to TE EP.

@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch 2 times, most recently from 52bbf88 to d6c5745 Compare June 13, 2026 00:08
NVTEShape idx_shape = nvte_tensor_shape(topk_idx);
void* idx_data = nvte_tensor_data(topk_idx);
NVTE_CHECK(idx_data != nullptr, "topk_idx data must not be null");

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 max_token_bytes hardcodes bfloat16 but float32 payloads are not blocked

cfg.max_token_bytes is set to hidden_dim * sizeof(nv_bfloat16) at group creation time. NCCL EP uses this to size internal staging buffers. If a caller constructs an EpBuffer with explicitly float32 recv_tokens and passes float32 tokens, the C++ dtype-match check in ep_dispatch passes (both tensors are float32), but max_token_bytes is half of what NCCL EP needs — leading to OOB writes inside ncclEpDispatch.

The default auto-alloc path is protected because recv_tokens defaults to handle.payload_dtype (bfloat16), making the C++ dtype-match check catch float32 tokens. However, a user who explicitly allocates float32 recv_tokens (a documented option in EpBuffer.__init__) and passes float32 tokens would silently bypass all guards and hit NCCL EP with undersized buffer configuration. The Python boundary only rejects FP8 (_reject_fp8), not float32.

Consider adding a validation in ep_bootstrap / EpHandle.__init__ that enforces bfloat16 at the payload boundary, and add an assertion in ep_dispatch (C++) that tok_dtype == kNVTEBFloat16 to fail fast instead of silently corrupting memory.

@timmoon10 timmoon10 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overal LGTM

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: A filename like expert_parallel.py would be a bit more obvious (#3034 (comment)).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ep_dispatch and ep_combine APIs are convenient and obvious, and it's especially nice how torch.compile support looks straightforward. If we ever want to fuse EP communication with grouped MLP compute, then we should also consider implementing a te.ops.BasicOperation for each, and then implementing a te.ops.FusedOperation for dispatch+FC1+act+FC2+combine (or some subset).

Comment thread transformer_engine/pytorch/ep.py Outdated
Comment on lines +465 to +467
ctx.handle_mem = handle_mem
ctx.handle_id = handle_id
ctx.grad_topk_weights_buf = grad_topk_weights_buf

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's recommended to store tensors with ctx.save_for_backward rather than assigning directly.

Suggested change
ctx.handle_mem = handle_mem
ctx.handle_id = handle_id
ctx.grad_topk_weights_buf = grad_topk_weights_buf
ctx.save_for_backward(handle_mem, grad_topk_weights_buf)
ctx.handle_id = handle_id

I don't think many of the concerns are relevant to us (these are persistent workspace buffers, so no grads or memory usage concerns), but there are some weird edge cases its supposed to avoid.

We'll also need to change the backward:

handle_mem, grad_topk_weights_buf = ctx.saved_tensors

We should also do a similar change in _EpCombine.

One final bug we might encounter is with CPU offloading, which automatically offloads saved tensors (see https://github.com/NVIDIA/TransformerEngine/pull/3035/changes#r3407248167).

Comment thread transformer_engine/pytorch/ep.py Outdated
Comment on lines +547 to +549
ctx.handle_mem = handle_mem
ctx.handle_id = handle_id
ctx.recv_tokens_grad = recv_tokens_grad

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid storing tensors directly in ctx:

Suggested change
ctx.handle_mem = handle_mem
ctx.handle_id = handle_id
ctx.recv_tokens_grad = recv_tokens_grad
ctx.save_for_backward(handle_mem, recv_tokens_grad)
ctx.handle_id = handle_id

Comment thread transformer_engine/pytorch/ep.py Outdated
self.grad_topk_weights = torch.empty(
(handle.max_tokens_per_rank, handle.top_k), dtype=torch.float32, device=device
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we save these workspace buffers in autograd contexts, then they might get picked up by CPU offloading.

Suggested change
# Prevent buffers from participating in activation CPU offloading
mark_not_offload(
self.recv_tokens,
self.recv_tokens_grad,
self.recv_topk_weights,
self.token_counts,
self.grad_topk_weights,
)

See:

def mark_not_offload(*tensors: torch.Tensor):

phu0ngng and others added 18 commits June 17, 2026 08:33
…ia cfg.zero_copy

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… a single buffer

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ents; drop unused ep_group kwargs

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… _-prefixed stub args, autograd docstrings)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…t_out to alias combine_in in zero-copy

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rad_tokens/grad_topk_weights; alias-check bwd grads in zero-copy

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…eights fp32 dtype

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… ep_dispatch

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…h/ep_combine accept caller-supplied output buffers with C++ symm-mem checks under zero-copy

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…normalize bwd grad layout in Python

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch from d6c5745 to 3e9e1cf Compare June 17, 2026 15:42
Comment on lines +308 to +309
check_symm_mem_required(grad, "grad (dispatch_bwd input)");
check_symm_mem_required(g_recv_topk_weights, "g_recv_topk_weights");

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Zero-copy backward always crashes — upstream grad tensors can never be symm-mem

ep_dispatch_bwd calls check_symm_mem_required on both grad (= g_recv_tokens) and g_recv_topk_weights, and ep_combine_bwd calls it on grad (= g_result). These tensors are allocated by PyTorch's autograd engine, not by the user via symm_mem_alloc, so they are never symm-mem-backed. When zero_copy=True, check_symm_mem_required is a hard error, meaning any training run that uses ep_dispatch or ep_combine with zero_copy=True will unconditionally throw in the backward pass: "ep zero-copy: grad (dispatch_bwd input) must be symm-mem-backed…". The existing autograd tests all set ZERO_COPY = False so this path is not exercised by the test suite. Either the backward path should fall back to staged-copy for non-symm-mem grad inputs (replacing check_symm_mem_required with maybe_make_window, matching the PR's stated "anything else falls back to staged-copy" contract), or EpBuffer needs symm-mem-allocated gradient slots and the autograd backward needs to route into them before calling these ops.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants