Skip to content

[PyTorch][torch.compile] Decouple amax reduction group from the quantizer#3104

Open
pggPL wants to merge 7 commits into
NVIDIA:mainfrom
pggPL:remove_process_group_from_quantizers
Open

[PyTorch][torch.compile] Decouple amax reduction group from the quantizer#3104
pggPL wants to merge 7 commits into
NVIDIA:mainfrom
pggPL:remove_process_group_from_quantizers

Conversation

@pggPL

@pggPL pggPL commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator

Description

There are 2 problems with amax_reduction_group on Quantizers:

  • I want to declare quantizers as opaque value objects which is some kind of Python constant - the ProcessGroup and tensor inside are problematic,
  • there is PyTorch for dealing with custom tensor classes like Float8Tensor in torch.compile: tensor_flatten() and tensor_unflatten() which assume that all internal tensors or opaque reference objects like process groups are directly parameters of a tensor. Currently, they are parameters of parameter (quantizer). I change that in this PR - amax reduction group is also stored on a QuantizedTensor when applicable.

Also, current design is prone to bugs. Amax reduction group is set in module forward(). So if we have forward for tp_group1, then forward for tp_group2, this second forwad overrides the amax_reduction_group of quantizers which are used in both backwards. So I think it is better design to set amax reduction group in forward and backward directly.
We may need to slightly refactor the quantizer to mitigate this kinds of bugs, but for now the change in this PR will be sufficient for torch.compile support.

Fixes # (N/A)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • TP sequence parallel: set the amax reduction group on the input / grad-output
    quantizer at point of use in the fwd/bwd impls (linear, layernorm_linear,
    layernorm_mlp, ops/basic_linear) via a new
    set_quantizer_amax_reduction_group helper, instead of once at module setup.
    Removed the group wiring from _customize_quantizers_float8_current_scaling
    and dropped _customize_quantizers_nvfp4 entirely.
  • FSDP2: store the group on Float8Tensor / NVFP4Tensor
    (amax_reduction_group attribute, set in fsdp_pre_all_gather) and apply it
    to a throwaway quantizer copy during the in-place re-quant
    (update_quantized / _set_data) — the weight's own quantizer is never mutated.
  • Quantizer.quantize() strips the group off the output tensor's quantizer, so
    it never persists on any tensor's _quantizer.
  • No C++ changes.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@pggPL

pggPL commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

@greptile-apps

greptile-apps Bot commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR decouples the per-tensor amax reduction process group from the quantizer to fix non-picklable process-group leakage onto Float8Tensor._quantizer, which blocked FSDP2 flatten/checkpointing. The group is now applied transiently at point-of-use (TP sequence-parallel) or carried on the tensor itself (FSDP2 weight), and Quantizer.quantize() strips it from the result tensor's _quantizer after each use.

  • TP sequence-parallel: set_quantizer_amax_reduction_group is called at the start of each forward/backward impl instead of once at module-setup time; the three separate _customize_quantizers_nvfp4 methods are removed and unified.
  • FSDP2: Float8Tensor / NVFP4Tensor gain an amax_reduction_group instance attribute set by fsdp_pre_all_gather; update_quantized / _set_data apply it via a throwaway quantizer copy.
  • Auxiliary fixes: LinearBwdArgs.fp8_recipe replaced with precomputed split-accumulator booleans; get_ub_is_fp8 cached with assume_constant_result; CustomRecipeState recipe identity check to handle recipe-object churn in tests.

Confidence Score: 5/5

Safe to merge; the refactor correctly decouples process-group lifetime from quantizer objects and the FSDP2/TP paths are logically equivalent to the old setup-time wiring.

All amax reduction group assignments have been verified to apply the correct groups at the correct call sites, the throwaway-copy pattern prevents group leakage to output tensor quantizers, and reduce_ex correctly excludes the new tensor-level amax_reduction_group attribute so pickling is fixed. The two remaining concerns (implicit module-quantizer mutation via aliased _quantizer and layernorm_mlp still reading fp8_recipe at backward time) are quality improvements rather than correctness blockers for the normal training path.

transformer_engine/pytorch/quantized_tensor.py (stripping-via-alias assumption) and transformer_engine/pytorch/module/layernorm_mlp.py (fp8_recipe access at backward time).

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantized_tensor.py Adds post-quantize stripping of amax_reduction_group from result tensor's _quantizer; for standard quantizers result._quantizer aliases the module quantizer so this also implicitly clears the module quantizer after each use (intended pop behaviour).
transformer_engine/pytorch/tensor/float8_tensor.py FSDP2 path now stores amax_reduction_group as a Float8Tensor instance attribute (excluded from reduce_ex); _set_data applies it via a throwaway quantizer copy; fsdp_pre_all_gather no longer mutates the module's quantizer object.
transformer_engine/pytorch/tensor/nvfp4_tensor.py Mirrors Float8Tensor FSDP2 changes; amax_reduction_group stored on NVFP4Tensor instance and applied in update_quantized via throwaway copy.
transformer_engine/pytorch/module/_common.py New set_quantizer_amax_reduction_group helper centralizes the group wiring; correctly unwraps DebugQuantizer to parent_quantizer.
transformer_engine/pytorch/module/linear.py Replaces fp8_recipe in LinearBwdArgs with precomputed split-accumulator booleans (compile-friendly); set_quantizer_amax_reduction_group called at forward/backward time instead of module-setup; clear_tensor_data extended to column-SP FP8 grad_output.
transformer_engine/pytorch/module/layernorm_linear.py Removes _customize_quantizers_nvfp4; sets amax_reduction_group at point-of-use in forward/backward; clear_tensor_data extended for column-SP FP8 grad_output.
transformer_engine/pytorch/module/layernorm_mlp.py Removes _customize_quantizers_nvfp4; sets fc1_input and fc2_grad_output groups at point-of-use; reads split-accum from saved ctx.fp8_recipe (still recipe-object access, unlike linear.py).
transformer_engine/pytorch/module/base.py Adds get_ub_is_fp8 with assume_constant_result for compile-graph stability; destroy_ub now calls torch.compiler.reset() to invalidate stale baked constants; CustomRecipeState identity check added to avoid skipping reinit on recipe object churn.
transformer_engine/pytorch/ops/basic/basic_linear.py Removes _customize_quantizers_nvfp4/float8 setup-time group wiring; replaces with three set_quantizer_amax_reduction_group calls at forward/backward sites.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant M as Module
    participant Q as Quantizer
    participant F as forward_impl
    participant B as backward_impl
    participant FT as Float8Tensor weight
    participant FSDP as FSDP2

    Note over M,B: TP sequence-parallel path
    M->>F: forward
    F->>Q: set_quantizer_amax_reduction_group(tp_group)
    F->>Q: quantize input - allreduce amax then cast
    Q->>Q: strip with_amax_reduction from result._quantizer
    F-->>B: save context with precomputed split-accum flags
    B->>Q: set_quantizer_amax_reduction_group(tp_group)
    B->>Q: re-quantize input for wgrad - allreduce amax
    Q->>Q: strip with_amax_reduction

    Note over FT,FSDP: FSDP2 weight path
    FSDP->>FT: fsdp_pre_all_gather - set amax_reduction_group on tensor
    FSDP->>FT: all-gather reconstructed weight used in compute
    FSDP->>FT: optimizer step triggers _set_data
    FT->>FT: create throwaway quantizer copy with group
    FT->>FT: quantize via throwaway - allreduce amax
    FT->>FT: result quantizer stripped - group never leaks
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant M as Module
    participant Q as Quantizer
    participant F as forward_impl
    participant B as backward_impl
    participant FT as Float8Tensor weight
    participant FSDP as FSDP2

    Note over M,B: TP sequence-parallel path
    M->>F: forward
    F->>Q: set_quantizer_amax_reduction_group(tp_group)
    F->>Q: quantize input - allreduce amax then cast
    Q->>Q: strip with_amax_reduction from result._quantizer
    F-->>B: save context with precomputed split-accum flags
    B->>Q: set_quantizer_amax_reduction_group(tp_group)
    B->>Q: re-quantize input for wgrad - allreduce amax
    Q->>Q: strip with_amax_reduction

    Note over FT,FSDP: FSDP2 weight path
    FSDP->>FT: fsdp_pre_all_gather - set amax_reduction_group on tensor
    FSDP->>FT: all-gather reconstructed weight used in compute
    FSDP->>FT: optimizer step triggers _set_data
    FT->>FT: create throwaway quantizer copy with group
    FT->>FT: quantize via throwaway - allreduce amax
    FT->>FT: result quantizer stripped - group never leaks
Loading

Reviews (9): Last reviewed commit: "Unwrap DebugQuantizer when setting amax ..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/quantized_tensor.py Outdated
Comment thread transformer_engine/pytorch/tensor/nvfp4_tensor.py Outdated
@pggPL

pggPL commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator Author

Blocked by FSDP bug, refactor in progress.

I plan to store .amax_reduction_group in QuantizedTensor.

@timmoon10 timmoon10 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be a design mistake. The amax reduction does not have a consistent meaning across recipes (including recipes where it doesn't make sense), and this change requires spilling out amax reduction logic into quantizer callsites (even where it doesn't make sense).

Can you go into more detail exactly why torch.compile doesn't work when quantizers have process groups? If we just want the quantizer to hold simple Python objects, maybe we can make the quantizer hold an int for the communicator ID. I envision something like:

class Float8CurrentScalingQuantizer(Quantizer):

    _communicator_cache = {}

    @property
    def amax_reduction_group(self):
        if self._amax_reduction_group_id is None:
            return None
        return Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id]

    @property.setter
    def amax_reduction_group(self, comm):
        if comm is None:
            self._amax_reduction_group_id = None
        self._amax_reduction_group_id = id(comm)
        Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id] = comm

I'm not sure how this would interact with checkpointing though.

dst: QuantizedTensor,
*,
noop_flag: Optional[torch.Tensor] = None,
amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I strongly oppose this API change. amax reduction is very recipe-specific. It has different meanings for different recipes (FP8 DS might reduce over the TP+DP group, FP8 CS might only reduce over the TP group) and it has no meaning for other recipes (MXFP8 and FP8 block scaling). Moving it into the generic API will leak recipe-specific information, defeating the point of a generic API.

@pggPL

pggPL commented Jun 10, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

pggPL and others added 4 commits June 15, 2026 16:40
…stants; fix SP memory leak; test suite hook-up

Wrap CommOverlapCore pybind11 methods that return compile-time constants
so torch.compile(fullgraph=True) can trace through them without graph
breaks:
- `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py;
  `_ub_is_fp8()` in gemm.py
- `with_cublasmp()` → `ub_is_cublasmp()` in base.py

All callers in linear.py, layernorm_linear.py, layernorm_mlp.py,
base.py, gemm.py, userbuffers_backward_linear.py and
userbuffers_forward_linear.py updated.

Fix quantized grad_output not being freed early for column-parallel SP
backward. Row-parallel SP already called clear_tensor_data(grad_output)
to release the gathered tensor; column-parallel SP quantizes grad_output
to Float8TensorStorage but never freed it before returning.  Under
torch.compile reduce-overhead this leaves 3 live pool tensors at
recording end and triggers "Detected 3 tensor(s) in the cudagraph pool
not tracked as outputs".  Extend the existing clear_tensor_data guard to
cover both parallel modes.

Fix custom-recipe quantizer state being re-initialised on every forward
call even when the recipe object has not changed. The existing early-exit
for CustomRecipeState was missing an identity check on the recipe object,
so any repeated call with the same recipe would bypass the early-return
and rebuild quantizers unnecessarily.  Add `if recipe_state.recipe is
recipe: return` to restore the intended caching behaviour.

Add test_torch_compile.py to L0_pytorch_unittest so the autocast and
existing compile tests run in CI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…-accumulator booleans

LinearBwdArgs stored the entire FP8 recipe object so the backward could
extract fp8_gemm_dgrad.use_split_accumulator and
fp8_gemm_wgrad.use_split_accumulator at GEMM time.  Recipe objects hold
process-group references and are not serialisable as compile-time
constants, making them incompatible with torch.compile custom-op paths.

Replace fp8_recipe with two plain bool fields:
- dgrad_use_split_accumulator (default _2X_ACC_DGRAD)
- wgrad_use_split_accumulator (default _2X_ACC_WGRAD)

These are resolved once in _linear_setup_ctx and passed into the args
struct, so the backward consumes scalars instead of a live recipe object.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…t_result

get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a
reset, destroy_ub + re-init with different FP8 settings would read stale
values until recompile. Only affects in-memory caches, not disk.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the remove_process_group_from_quantizers branch from e9097d6 to 948cd6d Compare June 16, 2026 12:23
pggPL and others added 2 commits June 16, 2026 16:32
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit
the no-roles warning, which graph-breaks under fullgraph=True. qfactory
dispatches on role.tensor_type instead of a pre-baked string key.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The amax reduction process group is no longer stored persistently on a module
quantizer or on a tensor's quantizer. No C++ changes.

- TP sequence parallel: the group is set on the input/grad-output quantizer at
  point of use in the fwd/bwd impls (linear, layernorm_linear, layernorm_mlp,
  ops basic_linear), replacing the setup-time _customize_quantizers wiring.
- FSDP2: the group is stored on Float8Tensor/NVFP4Tensor (set in
  fsdp_pre_all_gather) and applied to a throwaway quantizer copy during the
  in-place re-quant (update_quantized / _set_data).
- quantize() strips the group off the output tensor's quantizer so it never
  persists on any tensor's quantizer (breaks flatten/pickle otherwise).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the remove_process_group_from_quantizers branch from b8c1bec to 6c9b986 Compare June 16, 2026 14:56
@pggPL

pggPL commented Jun 17, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

@pggPL pggPL changed the title [PyTorch][torch.compile] Remove process group from quantizers [PyTorch][torch.compile] Decouple amax reduction group from the quantizer Jun 17, 2026
set_quantizer_amax_reduction_group was a no-op on a DebugQuantizer (it
lacks with_amax_reduction), so with nvinspect enabled the parent
quantizer never got the SP amax reduction group, breaking fp8 current
scaling column-parallel sequence-parallel numerics (debug test_numerics).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL

pggPL commented Jun 17, 2026

Copy link
Copy Markdown
Collaborator Author

@timmoon10 I have changed the concept of this PR. Note that it also consist the changed from #3130

@pggPL

pggPL commented Jun 17, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants