Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
357 changes: 356 additions & 1 deletion benchmarks/linear/benchmark_grouped_linear.py

Large diffs are not rendered by default.

325 changes: 325 additions & 0 deletions tests/pytorch/test_grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,331 @@ def test_grouped_gemm_cutlass_empty_groups(layout, monkeypatch):
torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0)


# =============================================================================
# NVFP4 per-tensor single-launch CUTLASS grouped GEMM (Blackwell / SM100).
#
# Opt-in via NVTE_NVFP4_CUTLASS_GROUPED_GEMM; a drop-in replacement for the
# production multi-stream cuBLASLt per-expert loop. These tests mirror the BF16
# Hopper cutlass coverage above -- GEMM-level (test_grouped_gemm), empty groups
# (test_grouped_gemm_cutlass_empty_groups) and module-level
# (test_grouped_linear_accuracy_cutlass) -- but for the NVFP4 per-tensor path,
# which is SM100-only and additionally fuses bias (fprop) / accumulate (wgrad).
# =============================================================================
# Name of the opt-in environment variable; the helpers below set it to "1"
# (cutlass) or "0" (multi-stream) through pytest's monkeypatch before each run.
_NVFP4_CUTLASS_ENV = "NVTE_NVFP4_CUTLASS_GROUPED_GEMM"
# skipif guard for every test in this section: NVFP4 must be available AND the
# current CUDA device must report compute-capability major version 10.
nvfp4_cutlass_grouped_available = nvfp4_available and torch.cuda.get_device_capability()[0] == 10


def _nvfp4_pertensor_quantize(hp: torch.Tensor):
    """Quantize a high-precision tensor to per-tensor NVFP4.

    Both rowwise and columnwise usages are requested and RHT, stochastic
    rounding and 2D quantization are all switched off, so operand
    canonicalization can pick either orientation and a plain ``dequantize()``
    is a faithful neutral reference.

    An operand with zero elements (empty expert) is never passed to the
    quantize call; the freshly allocated NVFP4 tensor of the requested shape is
    returned as-is.
    """
    quantizer = NVFP4Quantizer(
        fp4_dtype=tex.DType.kFloat4E2M1,
        rowwise=True,
        columnwise=True,
        with_rht=False,
        with_post_rht_amax=False,
        with_2d_quantization=False,
        stochastic_rounding=False,
        with_random_sign_mask=False,
    )
    quantized = quantizer.make_empty(hp.shape, dtype=torch.bfloat16, device=hp.device)
    if hp.numel() == 0:
        # Nothing to cast for an empty group; the allocated tensor is the result.
        return quantized
    tex.quantize(hp, quantizer, quantized, None)
    return quantized


def _diff(ref: torch.Tensor, test: torch.Tensor):
    """Return ``(max_abs, global_inf_norm_rel, ref_inf)`` for two tensors.

    Both inputs are promoted to fp32 first. ``max_abs`` is the largest
    element-wise absolute difference, ``ref_inf`` is the largest absolute value
    in ``ref``, and the relative figure is ``max_abs / max(ref_inf, 1e-6)``.
    Normalising by the global reference magnitude keeps the metric robust to
    near-zero output elements that make a per-element relative error explode.
    """
    ref_f32 = ref.float()
    worst_abs = (ref_f32 - test.float()).abs().max().item()
    ref_scale = ref_f32.abs().max().item()
    # Floor the denominator so an all-zero reference cannot divide by zero.
    return worst_abs, worst_abs / max(ref_scale, 1e-6), ref_scale


def _nvfp4_dequant_reference(A, B, *, layout: str, bias=None, init=None):
    """Build an fp32 reference without going through either GEMM backend.

    Each operand pair is dequantized to fp32 and multiplied with a plain
    matmul, which catches a wrong orientation / scale even when cutlass and
    multi-stream agree bit-for-bit. The rowwise dequant differs from the
    orientation the kernel consumes (esp. NN/NT), so this reference is a few %
    off *both* backends -- the test is that cutlass is no less accurate than
    the production path, not that it beats this reference.

    Per-group product by layout (A, B are the per-group operand lists):
      * "TN": ``B @ A^T``, plus ``bias[g]`` broadcast over rows when given.
      * "NN": ``B @ A``.
      * anything else (NT): ``B^T @ A``, plus ``init[g]`` when given.

    Returns a one-element list holding the row-wise concatenation of all
    groups for "TN"/"NN" (single-output layouts), otherwise the per-group list.
    """
    per_group = []
    for idx, (a_q, b_q) in enumerate(zip(A, B)):
        a_fp32 = a_q.dequantize(dtype=torch.float32)
        b_fp32 = b_q.dequantize(dtype=torch.float32)
        if layout == "TN":
            # out(m,N) = X(m,K) @ W(N,K)^T ; A=W, B=X
            prod = b_fp32 @ a_fp32.t()
            if bias is not None:
                prod = prod + bias[idx].float().unsqueeze(0)
        elif layout == "NN":
            # out(m,K) = dY(m,N) @ W(N,K) ; A=W, B=dY
            prod = b_fp32 @ a_fp32
        else:
            # NT: out(N,K) = dY(m,N)^T @ X(m,K) ; A=X, B=dY
            prod = b_fp32.t() @ a_fp32
            if init is not None:
                # Accumulate on top of the pre-existing buffer contents.
                prod = prod + init[idx].float()
        per_group.append(prod)

    if layout not in ("TN", "NN"):
        return per_group
    # single_output: concat over groups (M)
    return [torch.cat(per_group, dim=0)]


def _run_nvfp4_grouped(
    A, B, out, *, layout, grad, accumulate, m_splits, single_output, bias, cutlass, monkeypatch
):
    """Run one grouped GEMM with the requested backend selected via the env var.

    ``cutlass`` chooses the value written to ``_NVFP4_CUTLASS_ENV`` ("1" or
    "0") through ``monkeypatch`` before ``general_grouped_gemm`` is invoked.
    Results are written into ``out``; nothing is returned. Bias fusion is
    requested exactly when ``bias`` is not ``None``.
    """
    env_value = "1" if cutlass else "0"
    monkeypatch.setenv(_NVFP4_CUTLASS_ENV, env_value)

    # One None per group for the fourth positional argument.
    # NOTE(review): presumably the per-group quantizer slot -- confirm against
    # general_grouped_gemm's signature.
    per_group_none = [None] * len(A)
    fuse_bias = bias is not None
    general_grouped_gemm(
        A,
        B,
        out,
        per_group_none,
        out[0].dtype,
        layout=layout,
        m_splits=m_splits,
        single_output=single_output,
        grad=grad,
        accumulate=accumulate,
        bias=bias,
        use_bias=fuse_bias,
    )


def _assert_nvfp4_grouped_parity(out_ms, out_cu, hp):
    """Assert cutlass output parity and accuracy against two references.

    Args:
        out_ms: list of multi-stream (production reference) output tensors.
        out_cu: list of cutlass output tensors.
        hp: list of neutral dequantize-and-matmul reference tensors.

    Checks, per output tensor:
      (1) cutlass tracks multi-stream (absolute OR global-relative bound), and
      (2) cutlass is no less accurate than multi-stream against the neutral
          reference.

    The three lists must have the same length and matching per-entry shapes.
    Both are asserted up front, because ``zip`` would otherwise silently drop
    trailing entries and the subtraction in ``_diff`` would silently broadcast,
    letting a mis-built reference skip or dilute the comparison.

    Raises:
        AssertionError: on any count / shape mismatch or failed tolerance.
    """
    assert len(out_ms) == len(out_cu) == len(hp), (
        "output/reference count mismatch: "
        f"multi-stream={len(out_ms)}, cutlass={len(out_cu)}, reference={len(hp)}"
    )
    for ms, cu, h in zip(out_ms, out_cu, hp):
        assert ms.shape == cu.shape == h.shape, (
            "output/reference shape mismatch: "
            f"multi-stream={tuple(ms.shape)}, cutlass={tuple(cu.shape)}, "
            f"reference={tuple(h.shape)}"
        )
        if cu.numel() == 0:
            # Nothing to compare for an empty output.
            continue
        abs_d, rel_d, ms_inf = _diff(ms, cu)
        assert ms_inf > 1e-6, "reference output is ~0 (operand/quant bug, not a real check)"
        # Backend consistency: overwrite paths are bit-identical; bias / accumulate
        # add a ~1 ULP fp32/bf16 rounding diff. Allow either bound.
        assert (
            abs_d <= 5e-2 or rel_d <= 2e-2
        ), f"cutlass vs multi-stream diverged: max_abs={abs_d:.4g}, rel={rel_d:.4g}"
        # Correctness: cutlass no worse than production vs the neutral reference.
        _, ms_hp, _ = _diff(h, ms)
        _, cu_hp, _ = _diff(h, cu)
        assert cu_hp <= max(
            0.05, ms_hp * 1.3
        ), f"cutlass less accurate than multi-stream vs dequant ref: cu={cu_hp:.4g}, ms={ms_hp:.4g}"


def _build_nvfp4_grouped_operands(layout, m_splits, k, n, *, accumulate, use_bias, odt, dev):
    """Create per-tensor NVFP4 operands and output buffers for one case.

    Returns ``(A, B, out_ms, out_cu, grad, single_output, bias, init)``.
    ``out_ms`` and ``out_cu`` are two independent buffers with the same initial
    contents, so the multi-stream and cutlass runs never alias.

    Layouts:
      * "TN" (fprop): A=W(n,k), B=X(m,k) -> one out(m,n); optional bias.
      * "NN" (dgrad): A=W(n,k), B=dY(m,n) -> one dgrad(m,k).
      * otherwise (NT wgrad): A=X(m,k), B=dY(m,n) -> wgrad(n,k) per group,
        pre-filled from a shared random ``init`` when ``accumulate`` is set.
    """
    num_groups = len(m_splits)
    total_m = sum(m_splits)

    def quantized_randn(*shape):
        hp = torch.randn(*shape, dtype=torch.bfloat16, device=dev) * 0.5
        return _nvfp4_pertensor_quantize(hp)

    def zeros(*shape):
        return torch.zeros(*shape, dtype=odt, device=dev)

    bias = None
    init = None

    if layout == "TN":
        # fprop: A=W(n,k), B=X(m,k) -> out(m,n)
        A = [quantized_randn(n, k) for _ in range(num_groups)]
        B = [quantized_randn(rows, k) for rows in m_splits]
        out_ms = [zeros(total_m, n)]
        out_cu = [zeros(total_m, n)]
        if use_bias:
            bias = [torch.randn(n, dtype=odt, device=dev) for _ in range(num_groups)]
        return A, B, out_ms, out_cu, False, True, bias, init

    if layout == "NN":
        # dgrad: A=W(n,k), B=dY(m,n) -> dgrad(m,k)
        A = [quantized_randn(n, k) for _ in range(num_groups)]
        B = [quantized_randn(rows, n) for rows in m_splits]
        out_ms = [zeros(total_m, k)]
        out_cu = [zeros(total_m, k)]
        return A, B, out_ms, out_cu, True, True, bias, init

    # NT wgrad: A=X(m,k), B=dY(m,n) -> wgrad(n,k) per group
    A = [quantized_randn(rows, k) for rows in m_splits]
    B = [quantized_randn(rows, n) for rows in m_splits]
    if accumulate:
        # in-place main_grad accumulate: both backends share init
        init = [torch.randn(n, k, dtype=odt, device=dev) for _ in range(num_groups)]
        out_ms = [buf.clone() for buf in init]
        out_cu = [buf.clone() for buf in init]
    else:
        out_ms = [zeros(n, k) for _ in range(num_groups)]
        out_cu = [zeros(n, k) for _ in range(num_groups)]
    return A, B, out_ms, out_cu, True, False, bias, init


def _run_nvfp4_gemm_case(layout, fp32_out, accumulate, use_bias, m_splits, k, n, monkeypatch):
    """Build operands once, run both backends on them, then assert parity.

    The RNG is seeded with 0 before operand creation. Outputs are fp32 when
    ``fp32_out`` is set and bf16 otherwise. The multi-stream backend runs first
    into ``out_ms``, then cutlass into ``out_cu``; both are finally compared
    with each other and with the dequantize-and-matmul reference.
    """
    torch.manual_seed(0)
    out_dtype = torch.float32 if fp32_out else torch.bfloat16
    (
        A,
        B,
        out_ms,
        out_cu,
        grad,
        single_output,
        bias,
        init,
    ) = _build_nvfp4_grouped_operands(
        layout, m_splits, k, n, accumulate=accumulate, use_bias=use_bias, odt=out_dtype, dev="cuda"
    )

    # Everything except the destination buffer and backend flag is shared.
    shared_kwargs = dict(
        layout=layout,
        grad=grad,
        accumulate=accumulate,
        m_splits=m_splits,
        single_output=single_output,
        bias=bias,
        monkeypatch=monkeypatch,
    )
    for destination, use_cutlass in ((out_ms, False), (out_cu, True)):
        _run_nvfp4_grouped(A, B, destination, cutlass=use_cutlass, **shared_kwargs)

    reference = _nvfp4_dequant_reference(A, B, layout=layout, bias=bias, init=init)
    _assert_nvfp4_grouped_parity(out_ms, out_cu, reference)


# (id, layout, fp32_out, accumulate, use_bias) -- the capability matrix of the
# per-tensor path: TN fprop (+bias), NN dgrad, NT wgrad (fresh + accumulate).
# The first field is only used as the pytest parametrize id; the remaining four
# are unpacked by the tests and forwarded to _run_nvfp4_gemm_case. fp32_out
# selects a float32 output buffer (bfloat16 otherwise).
_NVFP4_GEMM_CASES = [
("fprop", "TN", False, False, False),
("fprop_bias", "TN", False, False, True),
("dgrad", "NN", False, False, False),
("wgrad", "NT", True, False, False),
("wgrad_accum", "NT", True, True, False),
]


@pytest.mark.skipif(
    not nvfp4_cutlass_grouped_available,
    reason="NVFP4 CUTLASS grouped GEMM requires Blackwell (SM100) + NVFP4",
)
@pytest.mark.parametrize(
    "shape",
    [(4, 512, 256, 256), (8, 1024, 512, 256), (2, 256, 256, 512)],
    ids=lambda s: f"{s[0]}x{s[1]}_{s[2]}x{s[3]}",
)
@pytest.mark.parametrize("case", _NVFP4_GEMM_CASES, ids=lambda c: c[0])
def test_nvfp4_cutlass_grouped_gemm(shape, case, monkeypatch):
    """GEMM-level parity with uniform (equal) per-group splits.

    ``shape`` is ``(num_groups, total_m, k, n)``. Operands are quantized ONCE
    (per-tensor NVFP4, all dims %128 so the path is eligible) and fed to both
    backends, so any divergence is a GEMM-backend bug, not a quant artifact.
    """
    num_groups, total_m, k, n = shape
    layout, fp32_out, accumulate, use_bias = case[1:]

    rows_per_group, leftover = divmod(total_m, num_groups)
    # Guard the parametrization itself: even split, everything 128-aligned.
    assert leftover == 0 and rows_per_group % 128 == 0 and k % 128 == 0 and n % 128 == 0

    uniform_splits = [rows_per_group] * num_groups
    _run_nvfp4_gemm_case(layout, fp32_out, accumulate, use_bias, uniform_splits, k, n, monkeypatch)


@pytest.mark.skipif(
    not nvfp4_cutlass_grouped_available,
    reason="NVFP4 CUTLASS grouped GEMM requires Blackwell (SM100) + NVFP4",
)
@pytest.mark.parametrize("case", _NVFP4_GEMM_CASES, ids=lambda c: c[0])
def test_nvfp4_cutlass_grouped_gemm_uneven_splits(case, monkeypatch):
    """Capability matrix again, but with unequal (still 128-aligned) group sizes.

    Mirrors the uneven token distribution real MoE routing produces and
    stresses the per-group offset / pointer setup that equal splits cannot.
    Uses k = n = 256.
    """
    layout, fp32_out, accumulate, use_bias = case[1:]
    # 128, 256, 384, 512: all %128, deliberately unequal.
    uneven_splits = [128 * multiple for multiple in (1, 2, 3, 4)]
    _run_nvfp4_gemm_case(
        layout, fp32_out, accumulate, use_bias, uneven_splits, 256, 256, monkeypatch
    )


@pytest.mark.skipif(
    not nvfp4_cutlass_grouped_available,
    reason="NVFP4 CUTLASS grouped GEMM requires Blackwell (SM100) + NVFP4",
)
@pytest.mark.parametrize("layout", ["TN", "NN"])
def test_nvfp4_cutlass_grouped_gemm_empty_groups(layout, monkeypatch):
    """Empty experts (0-token groups) mixed with non-empty ones.

    Tokens map to cuBLAS n for TN/NN, so an empty group is n=0; CUTLASS
    schedules 0 tiles for it and the non-empty groups must still compute (a
    single empty expert must NOT veto the batch). NT is omitted: an empty NT
    group is K=0, which the dispatcher keeps on multi-stream by design.

    Runs with bf16 output, no accumulate, no bias, k = n = 256.
    """
    splits_with_empty_experts = [0, 128, 0, 128]
    _run_nvfp4_gemm_case(
        layout,
        False,  # fp32_out
        False,  # accumulate
        False,  # use_bias
        splits_with_empty_experts,
        256,
        256,
        monkeypatch,
    )


@pytest.mark.skipif(
    not nvfp4_cutlass_grouped_available,
    reason="NVFP4 CUTLASS grouped GEMM requires Blackwell (SM100) + NVFP4",
)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("groups", [3, 6])
def test_nvfp4_cutlass_grouped_linear(groups, bias, fuse_wgrad_accumulation, monkeypatch):
    """End-to-end GroupedLinear fwd+bwd: cutlass (env=1) vs multi-stream (env=0).

    Both runs go through the real module wiring under a per-tensor NVFP4
    recipe and share one model, one input ``x`` and one upstream gradient
    ``dy`` (RNG state is reset before each run), so any divergence in output,
    dgrad, wgrad or dbias is a backend bug. Exercises fprop/dgrad/wgrad, fused
    bias and the main_grad (fuse_wgrad_accumulation) accumulate path, across
    odd/even group counts with unequal (128-aligned) per-expert token counts.
    """
    nvfp4_recipe = recipe.NVFP4BlockScaling(
        disable_rht=True,
        disable_stochastic_rounding=True,
        disable_2d_quantization=True,
    )
    K, N = 2048, 2048  # all dims %128 -> path eligible
    # Cycles through 128, 256, 384: unequal but 128-aligned.
    m_splits = [128 * ((i % 3) + 1) for i in range(groups)]
    total_m = sum(m_splits)

    # Creation order below fixes the random draws: model, x, dy, then main_grad.
    torch.manual_seed(0)
    model = GroupedLinear(
        groups,
        K,
        N,
        bias=bias,
        params_dtype=torch.bfloat16,
        device="cuda",
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()
    x = torch.randn(total_m, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    dy = torch.randn(total_m, N, dtype=torch.bfloat16, device="cuda")
    init_mg = None
    if fuse_wgrad_accumulation:
        init_mg = [torch.randn(N, K, dtype=torch.float32, device="cuda") for _ in range(groups)]

    def snapshot(tensor):
        # Detached fp32 copy so a later run cannot overwrite the recorded value.
        return tensor.detach().float().clone()

    def run(cutlass: bool):
        monkeypatch.setenv(_NVFP4_CUTLASS_ENV, "1" if cutlass else "0")
        reset_rng_states()  # identical quantization randoms in both runs

        # Clear every gradient left over from the previous run.
        model.zero_grad(set_to_none=True)
        x.grad = None
        if fuse_wgrad_accumulation:
            # Both runs accumulate on top of the same starting main_grad.
            for idx, start in enumerate(init_mg):
                getattr(model, f"weight{idx}").main_grad = start.clone()

        with autocast(enabled=True, recipe=nvfp4_recipe):
            out = model(x, m_splits)
        out.backward(dy)

        snap = {"out": snapshot(out), "dgrad": snapshot(x.grad)}
        for idx in range(groups):
            weight = getattr(model, f"weight{idx}")
            weight_grad = weight.main_grad if fuse_wgrad_accumulation else weight.grad
            snap[f"wgrad{idx}"] = snapshot(weight_grad)
            if not bias:
                continue
            bias_param = getattr(model, f"bias{idx}")
            if bias_param.grad is not None:
                snap[f"dbias{idx}"] = snapshot(bias_param.grad)
        return snap

    ref = run(cutlass=False)
    test = run(cutlass=True)
    for key, expected in ref.items():
        abs_d, rel_d, _ = _diff(expected, test[key])
        assert (
            abs_d <= 5e-2 or rel_d <= 2e-2
        ), f"{key}: cutlass vs multi-stream diverged (max_abs={abs_d:.4g}, rel={rel_d:.4g})"


def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
data = grouped_tensor.rowwise_data
if data is None:
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
cast/cast_grouped.cu
cast/cast_grouped_dbias.cu
gemm/cutlass_grouped_gemm.cu
gemm/nvfp4_cutlass_grouped_gemm.cu
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/graph_safe_group_hadamard_transform.cu
hadamard_transform/hadamard_transform.cu
Expand Down Expand Up @@ -345,6 +346,7 @@ set_property(
# CUTLASS kernels could cause hang in debug build
set(CUTLASS_KERNEL_SOURCES
gemm/cutlass_grouped_gemm.cu
gemm/nvfp4_cutlass_grouped_gemm.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
hadamard_transform/hadamard_transform_cast_fusion.cu)
Expand Down
Loading