Single-launch CUTLASS grouped GEMM for per-tensor NVFP4#3134
Conversation
Replace the per-expert multi-stream cuBLASLt loop in the per-tensor NVFP4
grouped path with one CUTLASS ptr-array grouped launch on SM100 (Blackwell).
Common:
- Add nvfp4_cutlass_grouped_gemm.{cuh,cu}: a single-launch per-tensor NVFP4
grouped kernel. Covers BF16 output (fprop/dgrad, overwrite), FP32 output
(wgrad, with optional in-place accumulate for Megatron wgrad fusion), and
optional fused per-group bias (fprop). Per-tensor scaling collapses the
second-level scale to one fp32 alpha per group, applied via the epilogue's
per-group alpha_ptr_array, Arch-gated on CUTLASS_ARCH_MMA_SM100_SUPPORTED.
- Wire it into nvte_multi_tensor_gemm behind the opt-in env
NVTE_NVFP4_CUTLASS_GROUPED_GEMM. M/N/K must be %128; empty (0-token)
experts schedule 0 tiles and no longer force a multi-stream fallback.
Anything unsupported falls through to the existing cuBLAS path, so default
behavior is unchanged.
- Add bench-only entry points nvte_nvfp4_grouped_per_tensor_compute_alpha /
nvte_nvfp4_grouped_per_tensor_gemm so a benchmark can precompute alpha
outside the timed region and time only the grouped GEMM launch.
PyTorch:
- Expose the two bench-only bindings (tex.nvfp4_grouped_per_tensor_compute_alpha
and tex.nvfp4_grouped_per_tensor_gemm). Not used by the production dispatch.
- Extend test_grouped_linear.py with NVFP4 cutlass-vs-multistream parity tests:
GEMM-level (uniform + uneven 128-aligned splits), empty groups, and
end-to-end GroupedLinear fwd+bwd (bias, fuse_wgrad_accumulation).
- Add a GEMM-level cutlass-vs-multistream comparison to
benchmark_grouped_linear.py (--compare-nvfp4-grouped-gemm): a DISPATCH row
(both backends via the shared dispatch) and a fair PURE row (operands
pre-swizzled, alpha precomputed; times only the GEMM).
Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR adds a single-launch CUTLASS ptr-array grouped GEMM for the per-tensor NVFP4 path on SM100 (Blackwell), replacing the per-expert multi-stream cuBLASLt loop. The new path is opt-in via
Confidence Score: 3/5Safe to merge with caveats: the default path is completely unchanged, but the opt-in CUTLASS path has unprotected static state that could corrupt buffers under concurrent CPU-thread access. The new CUTLASS kernel and dispatch logic are well-structured and the fallback to cuBLAS is robust. However, the static persistent buffers in persistent_buffer() and persistent_host_buffer() are read-modify-written without any mutex — any concurrent CPU-thread access produces a double-free and dangling pointer. Separately, hw_info.device_id=0 and cached_sm_count() always target physical device 0, diverging from the existing Hopper CUTLASS path and could cause incorrect tile scheduling in multi-GPU single-process configurations. transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu — the persistent buffer helpers and hardcoded device_id=0; transformer_engine/common/gemm/cublaslt_gemm.cu — the per-call alpha_buf allocation. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[nvte_multi_tensor_gemm] --> B{is_blackwell AND env=1?}
B -- No --> G[existing cuBLAS path]
B -- Yes --> C{eligible: NVFP4 per-tensor, shapes pct128, no gelu}
C -- No --> G
C -- Yes --> D[cudaMallocAsync alpha_buf, per-group alpha compute, skip empty experts]
D --> E{num_nonempty == 0?}
E -- Yes --> F[cudaFreeAsync + return]
E -- No --> H[nvfp4_cutlass run_grouped_per_tensor_gemm]
H --> I{fp32_output? has_bias?}
I -- BF16 plus bias --> J[run_impl bf16 kHasBias=true]
I -- FP32 --> K[run_impl float kHasBias=false]
I -- BF16 no bias --> L[run_impl bf16 kHasBias=false]
J & K & L --> M[Pack metadata to persistent device buffer, single H2D copy]
M --> N[CUTLASS GemmUniversalAdapter: can_implement, initialize, run]
N --> O[cudaFreeAsync alpha_buf + return]
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
A[nvte_multi_tensor_gemm] --> B{is_blackwell AND env=1?}
B -- No --> G[existing cuBLAS path]
B -- Yes --> C{eligible: NVFP4 per-tensor, shapes pct128, no gelu}
C -- No --> G
C -- Yes --> D[cudaMallocAsync alpha_buf, per-group alpha compute, skip empty experts]
D --> E{num_nonempty == 0?}
E -- Yes --> F[cudaFreeAsync + return]
E -- No --> H[nvfp4_cutlass run_grouped_per_tensor_gemm]
H --> I{fp32_output? has_bias?}
I -- BF16 plus bias --> J[run_impl bf16 kHasBias=true]
I -- FP32 --> K[run_impl float kHasBias=false]
I -- BF16 no bias --> L[run_impl bf16 kHasBias=false]
J & K & L --> M[Pack metadata to persistent device buffer, single H2D copy]
M --> N[CUTLASS GemmUniversalAdapter: can_implement, initialize, run]
N --> O[cudaFreeAsync alpha_buf + return]
|
| static void *persistent_buffer(size_t bytes, cudaStream_t stream, int which) { | ||
| static void *bufs[2] = {nullptr, nullptr}; | ||
| static size_t caps[2] = {0, 0}; | ||
| if (bytes > caps[which]) { | ||
| if (bufs[which] != nullptr) { | ||
| NVTE_CHECK_CUDA(cudaFreeAsync(bufs[which], stream)); | ||
| } | ||
| const size_t newcap = bytes + bytes / 2; // slack to avoid frequent regrows | ||
| NVTE_CHECK_CUDA(cudaMallocAsync(&bufs[which], newcap, stream)); | ||
| caps[which] = newcap; | ||
| } | ||
| return bufs[which]; | ||
| } | ||
|
|
||
| // Reusable pageable host staging buffer for the single batched H2D copy of all | ||
| // per-group metadata. Pageable (not pinned) is intentional: cudaMemcpyAsync | ||
| // stages pageable source into a driver buffer before returning, so the host | ||
| // buffer can be safely overwritten by the next launch without extra sync. | ||
| static void *persistent_host_buffer(size_t bytes) { | ||
| static std::vector<uint8_t> buf; | ||
| if (buf.size() < bytes) { | ||
| buf.resize(bytes + bytes / 2); | ||
| } | ||
| return buf.data(); | ||
| } |
There was a problem hiding this comment.
Non-thread-safe static persistent buffers
persistent_buffer() and persistent_host_buffer() both mutate static local variables (bufs[], caps[], buf) without any synchronization. The comment says "Assumes grouped GEMMs are issued serially on one stream (the TE norm)" — but CUDA stream ordering is GPU-side; two CPU threads can simultaneously reach the if (bytes > caps[which]) branch, both see a stale capacity, and both issue cudaFreeAsync on the same pointer followed by competing cudaMallocAsync writes back into bufs[which]. This produces a double-free and a dangling pointer even under the single-stream assumption. Multi-threaded callers (e.g. torch.compile with threading, or TP frameworks that invoke GEMM from multiple CPU threads) would be vulnerable. Adding a std::mutex around the resize block (or using the same mutex that guards the adjacent cuBLAS handle manager) would fix this.
| cutlass::KernelHardwareInfo hw_info; | ||
| hw_info.device_id = 0; | ||
| hw_info.sm_count = cached_sm_count(); |
There was a problem hiding this comment.
hw_info.device_id is hardcoded to 0 regardless of which GPU the current process is using. The existing Hopper CUTLASS grouped-GEMM path (cutlass_grouped_gemm.cuh) correctly propagates current_device down from the dispatcher (see cublaslt_gemm.cu line 1277), and cached_sm_count() also always queries device 0. In a multi-GPU process where the caller has selected a non-zero device, the SM count used for tile scheduling will be wrong. Please use transformer_engine::cuda::current_device() (or equivalent) to resolve the correct device, matching the existing pattern.
| cutlass::KernelHardwareInfo hw_info; | |
| hw_info.device_id = 0; | |
| hw_info.sm_count = cached_sm_count(); | |
| const int cur_device = transformer_engine::cuda::current_device(); | |
| cutlass::KernelHardwareInfo hw_info; | |
| hw_info.device_id = cur_device; | |
| hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(cur_device); |
Description
The per-tensor NVFP4 grouped GEMM currently runs as a per-expert loop of cuBLASLt GEMMs across multiple CUDA streams. This PR adds a single-launch CUTLASS ptr-array grouped kernel for the same path on SM100 (Blackwell): all experts are issued in one kernel launch. The new path is opt-in behind the env var
NVTE_NVFP4_CUTLASS_GROUPED_GEMM; the default behavior is unchanged. Anything the CUTLASS path does not support (non-SM100, non-NVFP4, shapes not %128, etc.) transparently falls through to the existing cuBLASLt path.Performance improvement
SM100,
python benchmarks/linear/benchmark_grouped_linear.py --compare-nvfp4-grouped-gemm. Times in ms;DISPATCH = full
nvte_multi_tensor_gemmpath (per-call scale swizzle + per-expert alpha compute included for both backends),PURE = kernel-vs-kernel (operands pre-swizzled, alpha precomputed; times only the GEMM).
TN (fprop)
NN (dgrad)
NT (wgrad)
Summary (cutlass vs multi-stream cuBLASLt):
The PURE rows show the raw single-launch CUTLASS kernel is ~3–4x faster than the per-expert multi-stream cuBLASLt loop for fprop/dgrad and ~2–2.7x for wgrad.
Type of change
Changes
Common (CUDA/C++)
gemm/nvfp4_cutlass_grouped_gemm.{cuh,cu}: a single-launch per-tensor NVFP4 grouped kernel. Per-tensor scaling collapses the second-level scale to one fp32alphaper group, applied via the epilogue's per-groupalpha_ptr_array(no per-row/col vector-broadcast EVT needed).Covers:
accumulate(Megatron wgrad fusion)CUTLASS_ARCH_MMA_SM100_SUPPORTED(error stub otherwise).nvte_multi_tensor_gemmbehindNVTE_NVFP4_CUTLASS_GROUPED_GEMM;unsupported cases fall back to cuBLAS.
nvte_nvfp4_grouped_per_tensor_compute_alpha/nvte_nvfp4_grouped_per_tensor_gemmto allow timing the pure GEMM with alphaprecomputed outside the timed region.
.cu(CMakeLists.txt).PyTorch
tex(
nvfp4_grouped_per_tensor_compute_alpha,nvfp4_grouped_per_tensor_gemm).Not used by the production dispatch.
tests/pytorch/test_grouped_linear.py: add NVFP4 cutlass-vs-multistreamparity tests — GEMM-level (uniform + uneven 128-aligned splits), empty
groups, and end-to-end
GroupedLinearfwd+bwd (bias, fuse_wgrad_accumulation).benchmarks/linear/benchmark_grouped_linear.py: add--compare-nvfp4-grouped-gemm— a DISPATCH row and a fair PURE row.Checklist: