From 79def34966bc227cbc459efec0be1304877010de Mon Sep 17 00:00:00 2001 From: Siddhartha Raman Date: Tue, 16 Jun 2026 11:39:05 -0700 Subject: [PATCH] Enable NVFP4 RHT amax for grouped SReLU MLP Signed-off-by: Siddhartha Raman --- tests/pytorch/test_fusible_ops.py | 46 +++++++++++++++++-- .../pytorch/ops/fused/grouped_mlp.py | 42 ++++++++++++----- 2 files changed, 71 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 43c7965518..abfc0f75f6 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3435,8 +3435,9 @@ def test_grouped_mlp( quantization: Optional[str], device: torch.device = "cuda", split_alignment: int = 256, + activation: str = "scaled_swiglu", ) -> None: - """GroupedLinear + ScaledSwiGLU + GroupedLinear""" + """GroupedLinear + scaled activation + GroupedLinear""" # Split sizes split_sizes = [split_alignment * (i) for i in range(group_size)] @@ -3446,13 +3447,20 @@ def test_grouped_mlp( # Make input shape in_shape = (split_sizes.sum().item(), hidden_size) out_shape = in_shape - fc1_out_features = 2 * hidden_size + if activation == "scaled_swiglu": + fc1_out_features = 2 * hidden_size + elif activation == "scaled_srelu": + fc1_out_features = hidden_size + else: + raise ValueError(f"Unexpected grouped MLP activation ({activation})") # Skip invalid configurations with_quantization = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if activation == "scaled_srelu" and quantization == "nvfp4_rht" and bias: + pytest.skip("NVFP4 RHT SReLU grouped MLP coverage is limited to no-bias") # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -3535,8 +3543,13 @@ def test_grouped_mlp( fc1_out = torch.nn.functional.linear( x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx] ) - act_in1, act_in2 = fc1_out.chunk(2, dim=-1) - act_out = torch.nn.functional.silu(act_in1) * act_in2 + if activation == "scaled_swiglu": + act_in1, act_in2 = fc1_out.chunk(2, dim=-1) + act_out = torch.nn.functional.silu(act_in1) * act_in2 + elif activation == "scaled_srelu": + act_out = torch.nn.functional.relu(fc1_out).square() + else: + raise ValueError(f"Unexpected grouped MLP activation ({activation})") fc2_in = act_out * probs[group_idx].unsqueeze(-1) y = torch.nn.functional.linear(fc2_in, fc2_ws_ref[group_idx]) if bias: @@ -3565,7 +3578,13 @@ def test_grouped_mlp( dtype=dtype, scale_bias=bias, ) - module = te.ops.Sequential(fc1, te_ops.ScaledSwiGLU(), fc2) + if activation == "scaled_swiglu": + activation_op = te_ops.ScaledSwiGLU() + elif activation == "scaled_srelu": + activation_op = te_ops.ScaledSReLU() + else: + raise ValueError(f"Unexpected grouped MLP activation ({activation})") + module = te.ops.Sequential(fc1, activation_op, fc2) # Copy weights with torch.no_grad(): @@ -3585,6 +3604,8 @@ def test_grouped_mlp( # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} + if quantization == "nvfp4_rht": + tols = {"rtol": 0.25, "atol": 0.5} # Check values assert_close(y_test, y_ref, **tols) @@ -3597,6 +3618,21 @@ def test_grouped_mlp( assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols) assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols) + def test_grouped_mlp_nvfp4_rht_srelu( + self, + *, + device: torch.device = "cuda", + ) -> None: + """GroupedLinear + ScaledSReLU + GroupedLinear with NVFP4 RHT amax.""" + + self.test_grouped_mlp( + bias=False, + dtype=torch.bfloat16, + quantization="nvfp4_rht", + device=device, + activation="scaled_srelu", + ) + class TestCustomOps: """Test with ops that are defined externally""" diff --git a/transformer_engine/pytorch/ops/fused/grouped_mlp.py b/transformer_engine/pytorch/ops/fused/grouped_mlp.py index 39180f098e..d83311b6a5 100644 --- a/transformer_engine/pytorch/ops/fused/grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/grouped_mlp.py @@ -871,7 +871,7 @@ def fuser_forward( basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: # Get basic operations - fc1_op, _, fc2_op = self.basic_ops + fc1_op, activation_op, fc2_op = self.basic_ops fc1_ctx, _activation_ctx, fc2_ctx = basic_op_ctxs # Tensor properties @@ -1151,17 +1151,20 @@ def fuser_forward( fc1_alpha_tensor = alpha_tensor use_tmem_post_rht_amax = _use_tmem_post_rht_amax() - use_fc1_glu_hadamard = False + use_fc1_act_hadamard = False + use_fc1_act_hadamard_srelu = False use_nvfp4_rht_amax = ( use_nvfp4 and isinstance(fc2_input_quantizer, NVFP4Quantizer) and fc2_input_quantizer.with_rht and fc2_input_quantizer.with_post_rht_amax ) - if use_nvfp4_rht_amax and self._cudnn_act_func == "swiglu": - kernel_getter = getattr(self, "grouped_gemm_glu_hadamard_kernel", None) + activation_is_srelu = isinstance(activation_op, ScaledSReLU) + if use_nvfp4_rht_amax and (self._cudnn_act_func == "swiglu" or activation_is_srelu): + kernel_getter = getattr(self, "grouped_gemm_act_hadamard_kernel", None) if kernel_getter is not None: - use_fc1_glu_hadamard = kernel_getter() is not None + use_fc1_act_hadamard = kernel_getter() is not None + use_fc1_act_hadamard_srelu = use_fc1_act_hadamard and activation_is_srelu fc1_activation_kwargs = { "a_tensor": fc1_x_data, @@ -1178,9 +1181,11 @@ def fuser_forward( "current_stream": current_stream, "use_dynamic_sched": True, } - if self._cudnn_act_func is not None: + if use_fc1_act_hadamard_srelu: + fc1_activation_kwargs["act_func"] = "srelu" + elif self._cudnn_act_func is not None: fc1_activation_kwargs["act_func"] = self._cudnn_act_func - if use_fc1_glu_hadamard: + if use_fc1_act_hadamard: fc1_activation_kwargs["use_tmem_post_rht_amax"] = use_tmem_post_rht_amax else: fc1_activation_kwargs["norm_const_tensor"] = fc1_norm_const_tensor @@ -1234,8 +1239,8 @@ def fuser_forward( fc1_activation_kwargs["b_dtype"] = data_dtype fc1_activation_kwargs["b_major"] = "k" - if use_fc1_glu_hadamard: - fc1_kernel_out = self.grouped_gemm_glu_hadamard_kernel()(**fc1_activation_kwargs) + if use_fc1_act_hadamard: + fc1_kernel_out = self.grouped_gemm_act_hadamard_kernel()(**fc1_activation_kwargs) else: fc1_kernel_out = self.grouped_gemm_activation_kernel()(**fc1_activation_kwargs) @@ -1269,7 +1274,7 @@ def fuser_forward( fc2_in = fc2_in.view(in_shape[0], fc2_weight_shape[1]).contiguous() fc2_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) fc2_input_quantizer.optimize_for_gemm = True - if use_fc1_glu_hadamard: + if use_fc1_act_hadamard: grouped_fc2_x = _group_quantize_with_amax_for_grouped_mlp( fc2_in, fc2_input_quantizer, @@ -2109,8 +2114,8 @@ def grouped_gemm_activation_kernel(cls) -> Callable: @classmethod @functools.lru_cache(maxsize=None) - def grouped_gemm_glu_hadamard_kernel(cls) -> Optional[Callable]: - """Fused grouped GEMM GLU kernel that also emits NVFP4 RHT amaxes.""" + def grouped_gemm_act_hadamard_kernel(cls) -> Optional[Callable]: + """Fused grouped GEMM activation kernel that also emits NVFP4 RHT amaxes.""" try: from cudnn import ( grouped_gemm_glu_hadamard_wrapper_sm100, @@ -2146,6 +2151,19 @@ def grouped_gemm_activation_kernel(cls) -> Callable: return grouped_gemm_srelu_wrapper_sm100 + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_act_hadamard_kernel(cls) -> Optional[Callable]: + """Fused grouped GEMM activation kernel that also emits NVFP4 RHT amaxes.""" + try: + from cudnn import ( + grouped_gemm_glu_hadamard_wrapper_sm100, + ) # pylint: disable=no-name-in-module,import-outside-toplevel + except ImportError: + return None + + return grouped_gemm_glu_hadamard_wrapper_sm100 + @classmethod @functools.lru_cache(maxsize=None) def grouped_gemm_dactivation_kernel(cls) -> Callable: