
Single-launch CUTLASS grouped GEMM for per-tensor NVFP4 #3134

Open
cael-ling wants to merge 2 commits into NVIDIA:main from cael-ling:optimize/group-gemm

@cael-ling cael-ling commented Jun 17, 2026

Description

The per-tensor NVFP4 grouped GEMM currently runs as a per-expert loop of cuBLASLt GEMMs across multiple CUDA streams. This PR adds a single-launch CUTLASS ptr-array grouped kernel for the same path on SM100 (Blackwell): all experts are issued in one kernel launch. The new path is opt-in behind the env var NVTE_NVFP4_CUTLASS_GROUPED_GEMM; the default behavior is unchanged. Anything the CUTLASS path does not support (non-SM100, non-NVFP4, shapes not %128, etc.) transparently falls through to the existing cuBLASLt path.
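The gating described above can be sketched as a small host-side decision function. This is an illustration only, not the code in this PR: GroupShape, Backend, shapes_eligible and choose_backend are hypothetical names, the env var is assumed to be enabled by the value "1", and the real predicate also checks for unsupported fusions that are not modeled here.

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <vector>

// Hypothetical per-group problem shape (illustrative, not a TE type).
struct GroupShape {
  long m, n, k;
};

enum class Backend { kCutlassGrouped, kCublasLtMultiStream };

// Assumption: the opt-in env var is considered enabled when set to "1".
inline bool cutlass_path_requested() {
  const char *v = std::getenv("NVTE_NVFP4_CUTLASS_GROUPED_GEMM");
  return v != nullptr && std::strcmp(v, "1") == 0;
}

// Shape part of the eligibility check: every dimension must be a multiple
// of 128. An empty expert (m == 0) passes, since 0 % 128 == 0, so it does
// not veto the whole batch.
inline bool shapes_eligible(const std::vector<GroupShape> &groups) {
  for (const auto &g : groups) {
    assert(g.m >= 0 && g.n >= 0 && g.k >= 0);
    if (g.m % 128 != 0 || g.n % 128 != 0 || g.k % 128 != 0) return false;
  }
  return true;
}

// Any failed condition falls through to the existing multi-stream path.
inline Backend choose_backend(bool is_sm100, bool is_nvfp4_per_tensor,
                              bool env_enabled,
                              const std::vector<GroupShape> &groups) {
  if (is_sm100 && is_nvfp4_per_tensor && env_enabled && shapes_eligible(groups)) {
    return Backend::kCutlassGrouped;
  }
  return Backend::kCublasLtMultiStream;
}
```

A caller would pass cutlass_path_requested() as env_enabled; keeping it a parameter makes the decision easy to unit-test without touching the process environment.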

Performance improvement

Measured on SM100 with: python benchmarks/linear/benchmark_grouped_linear.py --compare-nvfp4-grouped-gemm. Times are in ms. Each shape has two rows:
DISPATCH = full nvte_multi_tensor_gemm path (per-call scale swizzle and per-expert alpha compute included for both backends).
PURE = kernel-vs-kernel (operands pre-swizzled, alpha precomputed; only the GEMM is timed).

TN (fprop)

| shape | N | K | tokens | row | cutlass (ms) | multi-stream (ms) | speedup |
|---|---|---|---|---|---|---|---|
| 8e x 128 | 2048 | 2048 | 1024 | dispatch | 0.184 | 0.272 | 1.48x |
| 8e x 128 | 2048 | 2048 | 1024 | pure | 0.069 | 0.267 | 3.86x |
| 8e x 256 | 2048 | 2048 | 2048 | dispatch | 0.177 | 0.269 | 1.52x |
| 8e x 256 | 2048 | 2048 | 2048 | pure | 0.072 | 0.269 | 3.72x |
| 8e x 512 | 2048 | 2048 | 4096 | dispatch | 0.176 | 0.266 | 1.51x |
| 8e x 512 | 2048 | 2048 | 4096 | pure | 0.079 | 0.268 | 3.37x |
| 8e x imbal | 2048 | 2048 | 2304 | dispatch | 0.183 | 0.270 | 1.47x |
| 8e x imbal | 2048 | 2048 | 2304 | pure | 0.073 | 0.270 | 3.68x |
| 16e x 256 | 2048 | 2048 | 4096 | dispatch | 0.266 | 0.414 | 1.56x |
| 16e x 256 | 2048 | 2048 | 4096 | pure | 0.117 | 0.415 | 3.53x |
| 16e x imbal | 4096 | 2048 | 4224 | dispatch | 0.280 | 0.422 | 1.51x |
| 16e x imbal | 4096 | 2048 | 4224 | pure | 0.134 | 0.419 | 3.12x |
| 32e x 128 | 2048 | 2048 | 4096 | dispatch | 0.453 | 0.708 | 1.56x |
| 32e x 128 | 2048 | 2048 | 4096 | pure | 0.201 | 0.702 | 3.50x |
| 32e x imbal | 2048 | 2048 | 7296 | dispatch | 0.472 | 0.723 | 1.53x |
| 32e x imbal | 2048 | 2048 | 7296 | pure | 0.213 | 0.723 | 3.40x |

NN (dgrad)

| shape | N | K | tokens | row | cutlass (ms) | multi-stream (ms) | speedup |
|---|---|---|---|---|---|---|---|
| 8e x 128 | 2048 | 2048 | 1024 | dispatch | 0.190 | 0.278 | 1.47x |
| 8e x 128 | 2048 | 2048 | 1024 | pure | 0.070 | 0.275 | 3.95x |
| 8e x 256 | 2048 | 2048 | 2048 | dispatch | 0.185 | 0.278 | 1.50x |
| 8e x 256 | 2048 | 2048 | 2048 | pure | 0.073 | 0.277 | 3.81x |
| 8e x 512 | 2048 | 2048 | 4096 | dispatch | 0.186 | 0.279 | 1.50x |
| 8e x 512 | 2048 | 2048 | 4096 | pure | 0.080 | 0.276 | 3.46x |
| 8e x imbal | 2048 | 2048 | 2304 | dispatch | 0.183 | 0.281 | 1.53x |
| 8e x imbal | 2048 | 2048 | 2304 | pure | 0.073 | 0.278 | 3.80x |
| 16e x 256 | 2048 | 2048 | 4096 | dispatch | 0.281 | 0.433 | 1.54x |
| 16e x 256 | 2048 | 2048 | 4096 | pure | 0.122 | 0.430 | 3.54x |
| 16e x imbal | 4096 | 2048 | 4224 | dispatch | 0.283 | 0.428 | 1.51x |
| 16e x imbal | 4096 | 2048 | 4224 | pure | 0.131 | 0.433 | 3.31x |
| 32e x 128 | 2048 | 2048 | 4096 | dispatch | 0.462 | 0.718 | 1.55x |
| 32e x 128 | 2048 | 2048 | 4096 | pure | 0.203 | 0.718 | 3.54x |
| 32e x imbal | 2048 | 2048 | 7296 | dispatch | 0.478 | 0.731 | 1.53x |
| 32e x imbal | 2048 | 2048 | 7296 | pure | 0.214 | 0.726 | 3.39x |

NT (wgrad)

| shape | N | K | tokens | row | cutlass (ms) | multi-stream (ms) | speedup |
|---|---|---|---|---|---|---|---|
| 8e x 128 | 2048 | 2048 | 1024 | dispatch | 0.202 | 0.279 | 1.38x |
| 8e x 128 | 2048 | 2048 | 1024 | pure | 0.104 | 0.279 | 2.68x |
| 8e x 256 | 2048 | 2048 | 2048 | dispatch | 0.200 | 0.279 | 1.40x |
| 8e x 256 | 2048 | 2048 | 2048 | pure | 0.103 | 0.279 | 2.71x |
| 8e x 512 | 2048 | 2048 | 4096 | dispatch | 0.200 | 0.282 | 1.41x |
| 8e x 512 | 2048 | 2048 | 4096 | pure | 0.105 | 0.279 | 2.66x |
| 8e x imbal | 2048 | 2048 | 2304 | dispatch | 0.201 | 0.280 | 1.39x |
| 8e x imbal | 2048 | 2048 | 2304 | pure | 0.104 | 0.278 | 2.67x |
| 16e x 256 | 2048 | 2048 | 4096 | dispatch | 0.330 | 0.431 | 1.30x |
| 16e x 256 | 2048 | 2048 | 4096 | pure | 0.184 | 0.426 | 2.32x |
| 16e x imbal | 4096 | 2048 | 4224 | dispatch | 0.381 | 0.430 | 1.13x |
| 16e x imbal | 4096 | 2048 | 4224 | pure | 0.237 | 0.427 | 1.80x |
| 32e x 128 | 2048 | 2048 | 4096 | dispatch | 0.591 | 0.709 | 1.20x |
| 32e x 128 | 2048 | 2048 | 4096 | pure | 0.343 | 0.698 | 2.04x |
| 32e x imbal | 2048 | 2048 | 7296 | dispatch | 0.592 | 0.708 | 1.20x |
| 32e x imbal | 2048 | 2048 | 7296 | pure | 0.347 | 0.701 | 2.02x |

Summary (cutlass vs multi-stream cuBLASLt):

| layout | PURE (kernel-only) | DISPATCH (full path) |
|---|---|---|
| TN (fprop) | ~3.1–3.9x | ~1.47–1.56x |
| NN (dgrad) | ~3.3–4.0x | ~1.47–1.56x |
| NT (wgrad) | ~1.8–2.7x | ~1.13–1.41x |

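The PURE rows show the raw single-launch CUTLASS kernel is ~3–4x faster than the per-expert multi-stream cuBLASLt loop for fprop/dgrad and ~1.8–2.7x faster for wgrad. The DISPATCH rows show a smaller end-to-end gain (~1.1–1.6x) because the per-call scale swizzle and per-expert alpha compute are included for both backends.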
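As a worked example of how the two row types relate, the sketch below recomputes the speedup and the non-GEMM share of the CUTLASS dispatch time from the first TN shape (8e x 128). Timing, speedup and dispatch_overhead_ms are illustrative names. Reading dispatch minus pure as the swizzle plus alpha overhead is an interpretation based on how the rows are defined, not a separately measured number, and recomputing from the rounded table entries can differ from the reported speedup in the last digit.

```cpp
#include <cassert>

// One table row: time of each backend in milliseconds.
struct Timing {
  double cutlass_ms;
  double multistream_ms;
};

// Speedup as reported in the tables: multi-stream time over CUTLASS time.
inline double speedup(const Timing &t) {
  assert(t.cutlass_ms > 0.0);
  return t.multistream_ms / t.cutlass_ms;
}

// Time in the CUTLASS dispatch path that is not the GEMM kernel itself
// (interpreted here as scale swizzle + alpha compute + launch bookkeeping).
inline double dispatch_overhead_ms(const Timing &dispatch, const Timing &pure) {
  return dispatch.cutlass_ms - pure.cutlass_ms;
}

// Values copied from the TN (fprop) table, shape 8e x 128.
inline Timing tn_8e_128_dispatch() { return Timing{0.184, 0.272}; }
inline Timing tn_8e_128_pure() { return Timing{0.069, 0.267}; }
```

For this shape the difference is about 0.115 ms out of 0.184 ms, which suggests that most of the remaining dispatch time sits outside the GEMM kernel.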

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Common (CUDA/C++)

  • Add gemm/nvfp4_cutlass_grouped_gemm.{cuh,cu}: a single-launch per-tensor NVFP4 grouped kernel. Per-tensor scaling collapses the second-level scale to one fp32 alpha per group, applied via the epilogue's per-group alpha_ptr_array (no per-row/col vector-broadcast EVT needed).
    Covers:
    • BF16 output, overwrite (fprop / dgrad)
    • FP32 output (wgrad), with optional in-place accumulate (Megatron wgrad fusion)
    • optional fused per-group bias (fprop)
    • empty (0-token) experts
    • Arch-gated on CUTLASS_ARCH_MMA_SM100_SUPPORTED (error stub otherwise).
  • Wire it into nvte_multi_tensor_gemm behind NVTE_NVFP4_CUTLASS_GROUPED_GEMM;
    unsupported cases fall back to cuBLAS.
  • Add bench-only entry points nvte_nvfp4_grouped_per_tensor_compute_alpha /
    nvte_nvfp4_grouped_per_tensor_gemm to allow timing the pure GEMM with alpha
    precomputed outside the timed region.
  • Build the new .cu (CMakeLists.txt).
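The per-group semantics listed above (one fp32 alpha per group, optional bias, optional in-place accumulate, empty experts) can be written down as a CPU reference. This is a minimal sketch for checking intent, not the CUTLASS kernel: GroupProblem and reference_grouped_gemm are hypothetical names, operands are plain floats assumed to be already decoded from NVFP4, and the bias is assumed to be a length-n vector broadcast over rows. The PR only pairs bias with the BF16 fprop output and accumulate with the FP32 wgrad output; the reference does not enforce that pairing.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One group's problem: D (m x n) = alpha * A (m x k) * B (k x n) [+ bias] [+ D].
struct GroupProblem {
  std::size_t m, n, k;
  std::vector<float> a;  // row-major m x k
  std::vector<float> b;  // row-major k x n
  float alpha;           // single per-group scale applied in the epilogue
  const float *bias;     // length n, or nullptr when there is no bias
};

inline void reference_grouped_gemm(const std::vector<GroupProblem> &groups,
                                   std::vector<std::vector<float>> &d,
                                   bool accumulate) {
  assert(groups.size() == d.size());
  for (std::size_t g = 0; g < groups.size(); ++g) {
    const GroupProblem &p = groups[g];
    if (p.m == 0) continue;  // empty expert: no tiles, output untouched
    assert(p.a.size() == p.m * p.k && p.b.size() == p.k * p.n);
    assert(d[g].size() == p.m * p.n);
    for (std::size_t i = 0; i < p.m; ++i) {
      for (std::size_t j = 0; j < p.n; ++j) {
        float acc = 0.0f;
        for (std::size_t kk = 0; kk < p.k; ++kk) {
          acc += p.a[i * p.k + kk] * p.b[kk * p.n + j];
        }
        float out = p.alpha * acc;                    // per-group alpha
        if (p.bias != nullptr) out += p.bias[j];      // optional fused bias
        if (accumulate) out += d[g][i * p.n + j];     // optional in-place accumulate
        d[g][i * p.n + j] = out;
      }
    }
  }
}
```

A reference of this shape is mainly useful for reasoning about the epilogue; the parity tests in this PR compare against the multi-stream cuBLASLt path instead.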

PyTorch

  • Expose the two bench-only bindings via tex
    (nvfp4_grouped_per_tensor_compute_alpha, nvfp4_grouped_per_tensor_gemm).
    Not used by the production dispatch.
  • tests/pytorch/test_grouped_linear.py: add NVFP4 cutlass-vs-multistream
    parity tests — GEMM-level (uniform + uneven 128-aligned splits), empty
    groups, and end-to-end GroupedLinear fwd+bwd (bias, fuse_wgrad_accumulation).
  • benchmarks/linear/benchmark_grouped_linear.py: add
    --compare-nvfp4-grouped-gemm — a DISPATCH row and a fair PURE row.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Replace the per-expert multi-stream cuBLASLt loop in the per-tensor NVFP4
grouped path with one CUTLASS ptr-array grouped launch on SM100 (Blackwell).

Common:
- Add nvfp4_cutlass_grouped_gemm.{cuh,cu}: a single-launch per-tensor NVFP4
  grouped kernel. Covers BF16 output (fprop/dgrad, overwrite), FP32 output
  (wgrad, with optional in-place accumulate for Megatron wgrad fusion), and
  optional fused per-group bias (fprop). Per-tensor scaling collapses the
  second-level scale to one fp32 alpha per group, applied via the epilogue's
  per-group alpha_ptr_array. Arch-gated on CUTLASS_ARCH_MMA_SM100_SUPPORTED.
- Wire it into nvte_multi_tensor_gemm behind the opt-in env
  NVTE_NVFP4_CUTLASS_GROUPED_GEMM. M/N/K must be %128; empty (0-token)
  experts schedule 0 tiles and no longer force a multi-stream fallback.
  Anything unsupported falls through to the existing cuBLAS path, so default
  behavior is unchanged.
- Add bench-only entry points nvte_nvfp4_grouped_per_tensor_compute_alpha /
  nvte_nvfp4_grouped_per_tensor_gemm so a benchmark can precompute alpha
  outside the timed region and time only the grouped GEMM launch.

PyTorch:
- Expose the two bench-only bindings (tex.nvfp4_grouped_per_tensor_compute_alpha
  and tex.nvfp4_grouped_per_tensor_gemm). Not used by the production dispatch.
- Extend test_grouped_linear.py with NVFP4 cutlass-vs-multistream parity tests:
  GEMM-level (uniform + uneven 128-aligned splits), empty groups, and
  end-to-end GroupedLinear fwd+bwd (bias, fuse_wgrad_accumulation).
- Add a GEMM-level cutlass-vs-multistream comparison to
  benchmark_grouped_linear.py (--compare-nvfp4-grouped-gemm): a DISPATCH row
  (both backends via the shared dispatch) and a fair PURE row (operands
  pre-swizzled, alpha precomputed; times only the GEMM).

Signed-off-by: Cael Ling <caell@nvidia.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 17, 2026
@greptile-apps

greptile-apps Bot commented Jun 17, 2026


Greptile Summary

This PR adds a single-launch CUTLASS ptr-array grouped GEMM for the per-tensor NVFP4 path on SM100 (Blackwell), replacing the per-expert multi-stream cuBLASLt loop. The new path is opt-in via NVTE_NVFP4_CUTLASS_GROUPED_GEMM and transparently falls back to cuBLAS when unsupported shapes or features are requested.

  • New nvfp4_cutlass_grouped_gemm.cu implements the SM100 CUTLASS kernel covering BF16/FP32 output, fused per-group bias (fprop), and in-place wgrad accumulation; uses a persistent device/host buffer scheme to batch all per-group metadata into a single H2D copy per launch.
  • Dispatch logic added to nvte_multi_tensor_gemm: an eligible() predicate gates on NVFP4 per-tensor scaling, M/N/K % 128, Blackwell arch, and absence of unsupported fusions; empty experts (M=0) are handled without vetoing the whole batch.
  • Two bench-only entry points exposed via Python bindings to enable fair pure-GEMM timing with alpha precomputed outside the timed region.
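The "single H2D copy" scheme mentioned in the first bullet can be illustrated with a host-only packer: each per-group array is appended to one contiguous staging buffer at an aligned offset, so one copy moves everything and each device-side array is addressed as base plus offset. This is a sketch under stated assumptions, not the PR's implementation: MetadataPacker and align_up are hypothetical names, the 16-byte default alignment is an assumption, and the actual cudaMemcpyAsync is left out so the example runs without CUDA.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>
#include <vector>

// Round offset up to a multiple of align (align must be a power of two).
inline std::size_t align_up(std::size_t offset, std::size_t align) {
  assert(align != 0 && (align & (align - 1)) == 0);
  return (offset + align - 1) & ~(align - 1);
}

// Packs several metadata arrays into one staging buffer and returns the byte
// offset of each, so a single host-to-device copy can transfer all of them.
class MetadataPacker {
 public:
  template <typename T>
  std::size_t append(const std::vector<T> &values, std::size_t align = 16) {
    static_assert(std::is_trivially_copyable<T>::value,
                  "metadata must be trivially copyable");
    const std::size_t offset = align_up(staging_.size(), align);
    staging_.resize(offset + values.size() * sizeof(T));  // padding is zero-filled
    if (!values.empty()) {
      std::memcpy(staging_.data() + offset, values.data(),
                  values.size() * sizeof(T));
    }
    return offset;
  }

  const std::vector<std::uint8_t> &staging() const { return staging_; }

 private:
  std::vector<std::uint8_t> staging_;
};
```

In a CUDA version, staging().data() and staging().size() would feed one cudaMemcpyAsync into the persistent device buffer, and each returned offset would be added to the device base pointer.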

Confidence Score: 3/5

Safe to merge with caveats: the default path is completely unchanged, but the opt-in CUTLASS path has unprotected static state that could corrupt buffers under concurrent CPU-thread access.

The new CUTLASS kernel and dispatch logic are well-structured and the fallback to cuBLAS is robust. However, the static persistent buffers in persistent_buffer() and persistent_host_buffer() are read-modify-written without any mutex, so concurrent CPU-thread access can produce a double-free and a dangling pointer. Separately, hw_info.device_id=0 and cached_sm_count() always target physical device 0, which diverges from the existing Hopper CUTLASS path and could cause incorrect tile scheduling in multi-GPU single-process configurations.

transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu — the persistent buffer helpers and hardcoded device_id=0; transformer_engine/common/gemm/cublaslt_gemm.cu — the per-call alpha_buf allocation.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cu | New CUTLASS SM100 grouped GEMM kernel; thread-unsafe static persistent buffers and hardcoded device_id=0 need attention. |
| transformer_engine/common/gemm/cublaslt_gemm.cu | Dispatch wiring for the new CUTLASS path is correct; per-call cudaMallocAsync for alpha_buf is a minor performance miss on the hot path. |
| transformer_engine/common/gemm/nvfp4_cutlass_grouped_gemm.cuh | Clean public interface header; matches implementation correctly. |
| transformer_engine/common/include/transformer_engine/gemm.h | Two new bench-only API declarations added with clear documentation; no issues. |
| transformer_engine/pytorch/csrc/extensions/gemm.cpp | Python bindings for bench-only entry points correctly implemented with proper GIL handling. |
| tests/pytorch/test_grouped_linear.py | Good test coverage: GEMM-level parity (uniform + uneven splits), empty groups, and end-to-end GroupedLinear fwd+bwd. |
| benchmarks/linear/benchmark_grouped_linear.py | New --compare-nvfp4-grouped-gemm benchmark path is well-structured with correctly separated DISPATCH and PURE timing rows. |
| transformer_engine/common/CMakeLists.txt | New .cu file correctly added to both arch-specific and CUTLASS_KERNEL_SOURCES lists. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_multi_tensor_gemm] --> B{is_blackwell AND env=1?}
    B -- No --> G[existing cuBLAS path]
    B -- Yes --> C{eligible: NVFP4 per-tensor, shapes pct128, no gelu}
    C -- No --> G
    C -- Yes --> D[cudaMallocAsync alpha_buf, per-group alpha compute, skip empty experts]
    D --> E{num_nonempty == 0?}
    E -- Yes --> F[cudaFreeAsync + return]
    E -- No --> H[nvfp4_cutlass run_grouped_per_tensor_gemm]
    H --> I{fp32_output? has_bias?}
    I -- BF16 plus bias --> J[run_impl bf16 kHasBias=true]
    I -- FP32 --> K[run_impl float kHasBias=false]
    I -- BF16 no bias --> L[run_impl bf16 kHasBias=false]
    J & K & L --> M[Pack metadata to persistent device buffer, single H2D copy]
    M --> N[CUTLASS GemmUniversalAdapter: can_implement, initialize, run]
    N --> O[cudaFreeAsync alpha_buf + return]

Comments Outside Diff (1)

  1. transformer_engine/common/gemm/cublaslt_gemm.cu, line 848-850 (link)

    P2 Per-call cudaMallocAsync for alpha buffer

    alpha_buf is allocated with cudaMallocAsync on every dispatch call and freed with cudaFreeAsync after the GEMM. Stream-ordered pool allocations are cheap, but for latency-sensitive paths (especially with many small groups), this is a repeated small-allocation pattern the persistent-buffer mechanism was specifically designed to avoid. Adding alpha_buf to a third slot in persistent_buffer() inside nvfp4_cutlass_grouped_gemm.cu (or passing it as a persistent device buffer from the caller) would eliminate this overhead on the hot path.


Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +189 to +213
static void *persistent_buffer(size_t bytes, cudaStream_t stream, int which) {
  static void *bufs[2] = {nullptr, nullptr};
  static size_t caps[2] = {0, 0};
  if (bytes > caps[which]) {
    if (bufs[which] != nullptr) {
      NVTE_CHECK_CUDA(cudaFreeAsync(bufs[which], stream));
    }
    const size_t newcap = bytes + bytes / 2;  // slack to avoid frequent regrows
    NVTE_CHECK_CUDA(cudaMallocAsync(&bufs[which], newcap, stream));
    caps[which] = newcap;
  }
  return bufs[which];
}

// Reusable pageable host staging buffer for the single batched H2D copy of all
// per-group metadata. Pageable (not pinned) is intentional: cudaMemcpyAsync
// stages pageable source into a driver buffer before returning, so the host
// buffer can be safely overwritten by the next launch without extra sync.
static void *persistent_host_buffer(size_t bytes) {
  static std::vector<uint8_t> buf;
  if (buf.size() < bytes) {
    buf.resize(bytes + bytes / 2);
  }
  return buf.data();
}


P1 Non-thread-safe static persistent buffers

persistent_buffer() and persistent_host_buffer() both mutate static local variables (bufs[], caps[], buf) without any synchronization. The comment says "Assumes grouped GEMMs are issued serially on one stream (the TE norm)" — but CUDA stream ordering is GPU-side; two CPU threads can simultaneously reach the if (bytes > caps[which]) branch, both see a stale capacity, and both issue cudaFreeAsync on the same pointer followed by competing cudaMallocAsync writes back into bufs[which]. This produces a double-free and a dangling pointer even under the single-stream assumption. Multi-threaded callers (e.g. torch.compile with threading, or TP frameworks that invoke GEMM from multiple CPU threads) would be vulnerable. Adding a std::mutex around the resize block (or using the same mutex that guards the adjacent cuBLAS handle manager) would fix this.
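One possible shape for the mutex suggestion is sketched below. To stay runnable without CUDA it uses std::malloc/std::free as stand-ins for cudaMallocAsync/cudaFreeAsync and drops the stream parameter; persistent_buffer_locked is a hypothetical name, not a function in this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <mutex>

// Grow-only buffer cache with the check-and-regrow step serialized by a mutex.
// Host-memory stand-in: malloc/free replace cudaMallocAsync/cudaFreeAsync.
inline void *persistent_buffer_locked(std::size_t bytes, int which) {
  static std::mutex mu;
  static void *bufs[2] = {nullptr, nullptr};
  static std::size_t caps[2] = {0, 0};
  assert(which == 0 || which == 1);

  std::lock_guard<std::mutex> lock(mu);  // only one thread may inspect/resize
  if (bytes > caps[which]) {
    std::free(bufs[which]);  // free(nullptr) is a no-op
    const std::size_t newcap = bytes + bytes / 2;  // same slack policy as the PR
    bufs[which] = std::malloc(newcap);
    assert(bufs[which] != nullptr);
    caps[which] = newcap;
  }
  return bufs[which];
}
```

The lock only protects the bookkeeping. A pointer handed to one thread can still be released by another thread's regrow before the first thread has enqueued its work, so a complete fix would likely also need the lock held through the launch, or per-thread or per-stream buffers.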

Comment on lines +333 to +335
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = cached_sm_count();


P2 hw_info.device_id is hardcoded to 0 regardless of which GPU the current process is using. The existing Hopper CUTLASS grouped-GEMM path (cutlass_grouped_gemm.cuh) correctly propagates current_device down from the dispatcher (see cublaslt_gemm.cu line 1277), and cached_sm_count() also always queries device 0. In a multi-GPU process where the caller has selected a non-zero device, the SM count used for tile scheduling will be wrong. Please use transformer_engine::cuda::current_device() (or equivalent) to resolve the correct device, matching the existing pattern.

Suggested change
- cutlass::KernelHardwareInfo hw_info;
- hw_info.device_id = 0;
- hw_info.sm_count = cached_sm_count();
+ const int cur_device = transformer_engine::cuda::current_device();
+ cutlass::KernelHardwareInfo hw_info;
+ hw_info.device_id = cur_device;
+ hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(cur_device);
