Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3435,8 +3435,9 @@ def test_grouped_mlp(
quantization: Optional[str],
device: torch.device = "cuda",
split_alignment: int = 256,
activation: str = "scaled_swiglu",
) -> None:
"""GroupedLinear + ScaledSwiGLU + GroupedLinear"""
"""GroupedLinear + scaled activation + GroupedLinear"""

# Split sizes
split_sizes = [split_alignment * (i) for i in range(group_size)]
Expand All @@ -3446,13 +3447,20 @@ def test_grouped_mlp(
# Make input shape
in_shape = (split_sizes.sum().item(), hidden_size)
out_shape = in_shape
fc1_out_features = 2 * hidden_size
if activation == "scaled_swiglu":
fc1_out_features = 2 * hidden_size
elif activation == "scaled_srelu":
fc1_out_features = hidden_size
else:
raise ValueError(f"Unexpected grouped MLP activation ({activation})")

# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
if with_quantization and dtype not in (torch.bfloat16, torch.float16):
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
if activation == "scaled_srelu" and quantization == "nvfp4_rht" and bias:
pytest.skip("NVFP4 RHT SReLU grouped MLP coverage is limited to no-bias")
Comment thread
sraman-rgb marked this conversation as resolved.

# Random data
x_ref, x_test = make_reference_and_test_tensors(
Expand Down Expand Up @@ -3535,8 +3543,13 @@ def test_grouped_mlp(
fc1_out = torch.nn.functional.linear(
x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]
)
act_in1, act_in2 = fc1_out.chunk(2, dim=-1)
act_out = torch.nn.functional.silu(act_in1) * act_in2
if activation == "scaled_swiglu":
act_in1, act_in2 = fc1_out.chunk(2, dim=-1)
act_out = torch.nn.functional.silu(act_in1) * act_in2
elif activation == "scaled_srelu":
act_out = torch.nn.functional.relu(fc1_out).square()
else:
raise ValueError(f"Unexpected grouped MLP activation ({activation})")
fc2_in = act_out * probs[group_idx].unsqueeze(-1)
y = torch.nn.functional.linear(fc2_in, fc2_ws_ref[group_idx])
if bias:
Expand Down Expand Up @@ -3565,7 +3578,13 @@ def test_grouped_mlp(
dtype=dtype,
scale_bias=bias,
)
module = te.ops.Sequential(fc1, te_ops.ScaledSwiGLU(), fc2)
if activation == "scaled_swiglu":
activation_op = te_ops.ScaledSwiGLU()
elif activation == "scaled_srelu":
activation_op = te_ops.ScaledSReLU()
else:
raise ValueError(f"Unexpected grouped MLP activation ({activation})")
module = te.ops.Sequential(fc1, activation_op, fc2)

# Copy weights
with torch.no_grad():
Expand All @@ -3585,6 +3604,8 @@ def test_grouped_mlp(

# Loose tols for sanity checking
tols = {"rtol": 0.125, "atol": 0.25}
if quantization == "nvfp4_rht":
tols = {"rtol": 0.25, "atol": 0.5}

# Check values
assert_close(y_test, y_ref, **tols)
Expand All @@ -3597,6 +3618,21 @@ def test_grouped_mlp(
assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols)
assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols)

def test_grouped_mlp_nvfp4_rht_srelu(self, *, device: torch.device = "cuda") -> None:
    """GroupedLinear + ScaledSReLU + GroupedLinear with NVFP4 RHT amax.

    Delegates to ``test_grouped_mlp`` with one fixed configuration:
    no bias, BF16 data, ``"nvfp4_rht"`` quantization and the
    ``"scaled_srelu"`` activation. Only ``device`` is forwarded from the
    caller; every other argument of ``test_grouped_mlp`` keeps its default.
    """
    # Configuration pinned by this test case.
    fixed_config = {
        "bias": False,
        "dtype": torch.bfloat16,
        "quantization": "nvfp4_rht",
        "activation": "scaled_srelu",
    }
    self.test_grouped_mlp(device=device, **fixed_config)


class TestCustomOps:
"""Test with ops that are defined externally"""
Expand Down
42 changes: 30 additions & 12 deletions transformer_engine/pytorch/ops/fused/grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ def fuser_forward(
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
# Get basic operations
fc1_op, _, fc2_op = self.basic_ops
fc1_op, activation_op, fc2_op = self.basic_ops
fc1_ctx, _activation_ctx, fc2_ctx = basic_op_ctxs

# Tensor properties
Expand Down Expand Up @@ -1151,17 +1151,20 @@ def fuser_forward(
fc1_alpha_tensor = alpha_tensor

use_tmem_post_rht_amax = _use_tmem_post_rht_amax()
use_fc1_glu_hadamard = False
use_fc1_act_hadamard = False
use_fc1_act_hadamard_srelu = False
use_nvfp4_rht_amax = (
use_nvfp4
and isinstance(fc2_input_quantizer, NVFP4Quantizer)
and fc2_input_quantizer.with_rht
and fc2_input_quantizer.with_post_rht_amax
)
if use_nvfp4_rht_amax and self._cudnn_act_func == "swiglu":
kernel_getter = getattr(self, "grouped_gemm_glu_hadamard_kernel", None)
activation_is_srelu = isinstance(activation_op, ScaledSReLU)
if use_nvfp4_rht_amax and (self._cudnn_act_func == "swiglu" or activation_is_srelu):
kernel_getter = getattr(self, "grouped_gemm_act_hadamard_kernel", None)
if kernel_getter is not None:
use_fc1_glu_hadamard = kernel_getter() is not None
use_fc1_act_hadamard = kernel_getter() is not None
use_fc1_act_hadamard_srelu = use_fc1_act_hadamard and activation_is_srelu

fc1_activation_kwargs = {
"a_tensor": fc1_x_data,
Expand All @@ -1178,9 +1181,11 @@ def fuser_forward(
"current_stream": current_stream,
"use_dynamic_sched": True,
}
if self._cudnn_act_func is not None:
if use_fc1_act_hadamard_srelu:
fc1_activation_kwargs["act_func"] = "srelu"
elif self._cudnn_act_func is not None:
fc1_activation_kwargs["act_func"] = self._cudnn_act_func
if use_fc1_glu_hadamard:
if use_fc1_act_hadamard:
fc1_activation_kwargs["use_tmem_post_rht_amax"] = use_tmem_post_rht_amax
else:
fc1_activation_kwargs["norm_const_tensor"] = fc1_norm_const_tensor
Expand Down Expand Up @@ -1234,8 +1239,8 @@ def fuser_forward(
fc1_activation_kwargs["b_dtype"] = data_dtype
fc1_activation_kwargs["b_major"] = "k"

if use_fc1_glu_hadamard:
fc1_kernel_out = self.grouped_gemm_glu_hadamard_kernel()(**fc1_activation_kwargs)
if use_fc1_act_hadamard:
fc1_kernel_out = self.grouped_gemm_act_hadamard_kernel()(**fc1_activation_kwargs)
else:
fc1_kernel_out = self.grouped_gemm_activation_kernel()(**fc1_activation_kwargs)

Expand Down Expand Up @@ -1269,7 +1274,7 @@ def fuser_forward(
fc2_in = fc2_in.view(in_shape[0], fc2_weight_shape[1]).contiguous()
fc2_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
fc2_input_quantizer.optimize_for_gemm = True
if use_fc1_glu_hadamard:
if use_fc1_act_hadamard:
grouped_fc2_x = _group_quantize_with_amax_for_grouped_mlp(
fc2_in,
fc2_input_quantizer,
Expand Down Expand Up @@ -2109,8 +2114,8 @@ def grouped_gemm_activation_kernel(cls) -> Callable:

@classmethod
@functools.lru_cache(maxsize=None)
def grouped_gemm_glu_hadamard_kernel(cls) -> Optional[Callable]:
"""Fused grouped GEMM GLU kernel that also emits NVFP4 RHT amaxes."""
def grouped_gemm_act_hadamard_kernel(cls) -> Optional[Callable]:
"""Fused grouped GEMM activation kernel that also emits NVFP4 RHT amaxes."""
try:
from cudnn import (
grouped_gemm_glu_hadamard_wrapper_sm100,
Expand Down Expand Up @@ -2146,6 +2151,19 @@ def grouped_gemm_activation_kernel(cls) -> Callable:

return grouped_gemm_srelu_wrapper_sm100

@classmethod
@functools.lru_cache(maxsize=None)
def grouped_gemm_act_hadamard_kernel(cls) -> Optional[Callable]:
    """Fused grouped GEMM activation kernel that also emits NVFP4 RHT amaxes.

    Returns ``grouped_gemm_glu_hadamard_wrapper_sm100`` imported from the
    ``cudnn`` module, or ``None`` when that import raises ``ImportError``.
    The lookup is memoized by ``functools.lru_cache``, so the import is
    attempted at most once per ``cls``.
    """
    # pylint: disable=no-name-in-module,import-outside-toplevel
    try:
        from cudnn import grouped_gemm_glu_hadamard_wrapper_sm100 as kernel
    except ImportError:
        # The wrapper is optional: callers check for ``None`` and fall back.
        kernel = None
    return kernel

@classmethod
@functools.lru_cache(maxsize=None)
def grouped_gemm_dactivation_kernel(cls) -> Callable:
Expand Down
Loading