Skip to content
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_torch_compile.xml $TE_PATH/tests/pytorch/test_torch_compile.py || test_fail "test_torch_compile.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
Expand Down
41 changes: 31 additions & 10 deletions tests/pytorch/test_torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.quantization import QuantizerRole
from transformer_engine.pytorch import (
is_fp8_available,
is_mxfp8_available,
Expand Down Expand Up @@ -123,21 +124,25 @@ def __fx_repr__(self):
_Q = get_opaque_type_name(ToyQuantizer)

def _make_qfactory(tag: str):
    """Return a qfactory that produces ToyQuantizer instances tagged with *tag*.

    The returned callable takes a ``QuantizerRole`` and dispatches on its
    ``tensor_type`` attribute; the roles are supplied by
    :meth:`ToyLinear.get_quantizer_roles`.

    Args:
        tag: Prefix embedded in every quantizer's tag as ``"<tag>:<tensor_type>"``.

    Returns:
        A ``qfactory(role)`` callable. It raises ``KeyError`` for a role whose
        ``tensor_type`` is not one of the five slots listed below.
    """
    # Build one quantizer per tensor slot up front, so repeated lookups for the
    # same slot hand back the same ToyQuantizer instance.
    tensor_types = ("input", "weight", "output", "grad_output", "grad_input")
    quantizers = {
        tensor_type: ToyQuantizer(tag=f"{tag}:{tensor_type}") for tensor_type in tensor_types
    }

    def qfactory(role: QuantizerRole):
        return quantizers[role.tensor_type]

    return qfactory

Expand All @@ -163,6 +168,22 @@ def __init__(
)
torch.nn.init.normal_(self.weight)

def get_quantizer_roles(self, *, fwd: bool, num_quantizers: int):
    """Return the ``QuantizerRole`` for each quantizer slot of this module.

    Args:
        fwd: ``True`` selects the forward slots (input, weight, output);
            ``False`` selects the backward slots (grad_output, grad_input).
        num_quantizers: Accepted for interface compatibility; not consulted here.

    Returns:
        A list of ``QuantizerRole`` objects with ``module_type="linear"``.
    """
    # Explicit roles keep CustomRecipeState from emitting a warning (which
    # would graph-break under fullgraph=True) and let the qfactory dispatch
    # per tensor slot. The order must match the module's quantizer array
    # (FP8FwdTensorIdx / FP8BwdTensorIdx).
    slot_names = ("input", "weight", "output") if fwd else ("grad_output", "grad_input")
    return [QuantizerRole(module_type="linear", tensor_type=slot) for slot in slot_names]

def _get_weight_tensors(self):
    """Return this module's weight tensors as a list (here, only ``self.weight``)."""
    return [self.weight]

Expand Down
15 changes: 15 additions & 0 deletions transformer_engine/pytorch/module/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@
from ..utils import get_default_init_method


def set_quantizer_amax_reduction_group(quantizer, amax_reduction_group) -> None:
    """Point a quantizer's amax reduction at *amax_reduction_group*.

    A quantizer exposing ``parent_quantizer`` (e.g. ``DebugQuantizer``) is
    unwrapped first, because the parent is the one that actually performs the
    quantization and therefore the amax reduction.

    Args:
        quantizer: Quantizer to configure; ``None`` is accepted and ignored.
        amax_reduction_group: Group to reduce amax over, or ``None`` to turn
            amax reduction off.

    The call is a no-op when the (unwrapped) quantizer is ``None`` or has no
    ``with_amax_reduction`` attribute.
    """
    if quantizer is None:
        return

    # DebugQuantizer delegates quantization to parent_quantizer.
    try:
        inner = quantizer.parent_quantizer
    except AttributeError:
        inner = quantizer

    if inner is None or not hasattr(inner, "with_amax_reduction"):
        return

    reduction_enabled = amax_reduction_group is not None
    inner.with_amax_reduction = reduction_enabled
    inner.amax_reduction_group = amax_reduction_group


def _get_normalization_func(normalization: str, forward: bool):
fwd_normalization_funcs = {
"LayerNorm": tex.layernorm_fwd,
Expand Down
12 changes: 11 additions & 1 deletion transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,12 @@ def get_ub(name: str, use_fp8: bool):
return _ub_communicators[key]


@torch.compiler.assume_constant_result
def get_ub_is_fp8(name: str, use_fp8: bool) -> bool:
    """Report whether the named UB communicator uses an FP8 buffer.

    Looks the communicator up via ``get_ub(name, use_fp8)`` and returns its
    ``is_fp8_ubuf()`` result. The decorator marks the result as a
    compile-time constant for ``torch.compile``.
    """
    communicator = get_ub(name, use_fp8)
    return communicator.is_fp8_ubuf()


def destroy_ub():
"""Destroy all allocated userbuffer communicators."""
global _ub_communicators, _ub_with_cublasmp, _ub_initialized
Expand All @@ -562,6 +568,9 @@ def destroy_ub():
_ub_initialized = False
global layers_atomic_ring_exchange
layers_atomic_ring_exchange = []
# Compiled graphs may have baked is_fp8_ubuf() via assume_constant_result;
# reset so re-init with different settings doesn't read stale constants.
torch.compiler.reset()


def fill_userbuffers_buffer_for_all_gather(
Expand Down Expand Up @@ -1049,7 +1058,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState):
return
if recipe.custom() and isinstance(recipe_state, CustomRecipeState):
return
if recipe_state.recipe is recipe:
return

# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
Expand Down
78 changes: 27 additions & 51 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_ub,
get_ub_is_fp8,
is_ub_initialized,
using_cublasmp_backend,
quantize_weight,
Expand Down Expand Up @@ -58,7 +59,12 @@
from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, WeightGradStore
from ._common import (
apply_normalization,
noop_cat,
set_quantizer_amax_reduction_group,
WeightGradStore,
)
from ..quantized_tensor import (
QuantizedTensor,
QuantizedTensorStorage,
Expand Down Expand Up @@ -215,6 +221,11 @@ def forward(
if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather():
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
# Amax reduction group for the input quantizer (column-parallel sequence parallel)
set_quantizer_amax_reduction_group(
input_quantizer,
tp_group if (sequence_parallel and parallel_mode == "column") else None,
)

# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
Expand Down Expand Up @@ -690,6 +701,15 @@ def backward(
# tensor usage at a time. Configure quantizer with
# usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False)
# Amax reduction group for grad output (row-parallel sequence parallel)
set_quantizer_amax_reduction_group(
quantizer,
(
ctx.tp_group
if (ctx.sequence_parallel and ctx.parallel_mode == "row")
else None
),
)

# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
Expand Down Expand Up @@ -1051,8 +1071,10 @@ def wgrad_gemm(
if ctx.ln_out_needs_gather:
# Gathered input is internal
clear_tensor_data(ln_out_total)
if ctx.parallel_mode == "row" and ctx.sequence_parallel:
# Gathered grad output tensor is internal
if ctx.sequence_parallel and (
ctx.parallel_mode == "row" or (ctx.parallel_mode == "column" and ctx.fp8)
):
# Gathered (row-SP) or quantized (column-SP FP8) grad_output is internal
clear_tensor_data(grad_output)

# Update grad input if overlapping reduce-scatter with wgrad GEMM
Expand Down Expand Up @@ -1552,8 +1574,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)

def get_quantizer_roles(
self,
Expand Down Expand Up @@ -1668,14 +1688,10 @@ def forward(
is_first_microbatch = False

if self.ub_overlap_rs_fprop:
if get_ub(
self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
if get_ub_is_fp8(self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()):
fp8_output = True
if self.ub_overlap_rs_dgrad:
if get_ub(
self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
if get_ub_is_fp8(self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()):
fp8_grad = True

inp = self.prepare_forward(
Expand Down Expand Up @@ -1919,15 +1935,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_WEIGHT
].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
# parallel related
if self.sequence_parallel and self.parallel_mode == "column":
# set input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
# set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here)
self.quantizers["scaling_bwd"][
Expand All @@ -1936,37 +1943,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
# parallel related
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group

def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
    """Enable TP-group amax reduction on NVFP4 quantizers for layernorm_linear.

    Args:
        fwd: ``True`` configures the forward input quantizer (only with
            sequence parallelism and ``parallel_mode == "column"``); ``False``
            configures the grad-output quantizer (only with sequence
            parallelism and ``parallel_mode == "row"``).
        recipe: Active recipe; must report ``nvfp4()``.

    Raises:
        AssertionError: If *recipe* is not an NVFP4 recipe.
    """
    assert recipe.nvfp4(), "Incorrect recipe."
    if not self.sequence_parallel:
        return

    if fwd:
        if self.parallel_mode != "column":
            return
        # Input quantizer reduces amax across the TP group.
        target = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT]
    else:
        if self.parallel_mode != "row":
            return
        # Grad-output quantizer reduces amax across the TP group.
        target = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1]

    target.with_amax_reduction = True
    target.amax_reduction_group = self.tp_group

def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
Expand Down
56 changes: 13 additions & 43 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
fill_userbuffers_buffer_for_all_gather,
_ub_communicators,
get_ub,
get_ub_is_fp8,
is_ub_initialized,
using_cublasmp_backend,
quantize_weight,
Expand Down Expand Up @@ -73,7 +74,7 @@
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
from ._common import apply_normalization, set_quantizer_amax_reduction_group, WeightGradStore
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
Expand Down Expand Up @@ -399,6 +400,11 @@ def _forward(
if sequence_parallel and fc1_input_quantizer.supports_only_rowwise_all_gather():
# All-gather is not supported with FP8 column-wise data
fc1_input_quantizer.set_usage(columnwise=False)
# Amax reduction group for the FC1 input quantizer (column-parallel sequence parallel)
set_quantizer_amax_reduction_group(
fc1_input_quantizer,
tp_group if (sequence_parallel and set_parallel_mode) else None,
)

# for fp8 DelayedScaling: layernorm output = FP8
# only output of the linear is returned
Expand Down Expand Up @@ -1138,6 +1144,11 @@ def backward(
# tensor usage at a time. Configure quantizer with
# usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False)
# Amax reduction group for FC2 grad output (row-parallel sequence parallel)
set_quantizer_amax_reduction_group(
quantizer,
ctx.tp_group if (ctx.sequence_parallel and ctx.set_parallel_mode) else None,
)

# Prepare FC2 grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
Expand Down Expand Up @@ -2165,8 +2176,6 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)

def get_quantizer_roles(
self,
Expand Down Expand Up @@ -2292,7 +2301,7 @@ def forward(

fp8_output = False
if self.ub_overlap_rs:
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
if get_ub_is_fp8("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()):
fp8_output = True

inp = self.prepare_forward(inp, num_gemms=2)
Expand Down Expand Up @@ -2676,15 +2685,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM2_WEIGHT
].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
# parallel related
if self.sequence_parallel and self.set_parallel_mode:
# fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
FP8FwdTensorIdx.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
# fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer
self.quantizers["scaling_bwd"][
Expand All @@ -2700,36 +2700,6 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
if self.sequence_parallel and self.set_parallel_mode:
# fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT2
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
FP8BwdTensorIdx.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group

def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
    """Enable TP-group amax reduction on NVFP4 quantizers for layernorm_mlp.

    Nothing is changed unless both ``self.sequence_parallel`` and
    ``self.set_parallel_mode`` are truthy.

    Args:
        fwd: ``True`` configures the FC1 input quantizer (column parallel +
            sequence parallel); ``False`` configures the FC2 grad-output
            quantizer (row parallel + sequence parallel).
        recipe: Active recipe; must report ``nvfp4()``.

    Raises:
        AssertionError: If *recipe* is not an NVFP4 recipe.
    """
    assert recipe.nvfp4(), "Incorrect recipe."
    if not (self.sequence_parallel and self.set_parallel_mode):
        return

    if fwd:
        # fc1_input_quantizer: column parallel + sequence parallel here
        target = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT]
    else:
        # fc2_grad_output_quantizer: row parallel + sequence parallel here
        target = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT2]

    target.with_amax_reduction = True
    target.amax_reduction_group = self.tp_group

def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
Expand Down
Loading
Loading