[PyTorch][torch.compile] Decouple amax reduction group from the quantizer#3104
[PyTorch][torch.compile] Decouple amax reduction group from the quantizer#3104pggPL wants to merge 7 commits into
Conversation
|
/te-ci pytorch L1 |
Greptile SummaryThis PR decouples the per-tensor amax reduction process group from the quantizer to fix non-picklable process-group leakage onto
Confidence Score: 5/5Safe to merge; the refactor correctly decouples process-group lifetime from quantizer objects and the FSDP2/TP paths are logically equivalent to the old setup-time wiring. All amax reduction group assignments have been verified to apply the correct groups at the correct call sites, the throwaway-copy pattern prevents group leakage to output tensor quantizers, and reduce_ex correctly excludes the new tensor-level amax_reduction_group attribute so pickling is fixed. The two remaining concerns (implicit module-quantizer mutation via aliased _quantizer and layernorm_mlp still reading fp8_recipe at backward time) are quality improvements rather than correctness blockers for the normal training path. transformer_engine/pytorch/quantized_tensor.py (stripping-via-alias assumption) and transformer_engine/pytorch/module/layernorm_mlp.py (fp8_recipe access at backward time). Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant M as Module
participant Q as Quantizer
participant F as forward_impl
participant B as backward_impl
participant FT as Float8Tensor weight
participant FSDP as FSDP2
Note over M,B: TP sequence-parallel path
M->>F: forward
F->>Q: set_quantizer_amax_reduction_group(tp_group)
F->>Q: quantize input - allreduce amax then cast
Q->>Q: strip with_amax_reduction from result._quantizer
F-->>B: save context with precomputed split-accum flags
B->>Q: set_quantizer_amax_reduction_group(tp_group)
B->>Q: re-quantize input for wgrad - allreduce amax
Q->>Q: strip with_amax_reduction
Note over FT,FSDP: FSDP2 weight path
FSDP->>FT: fsdp_pre_all_gather - set amax_reduction_group on tensor
FSDP->>FT: all-gather reconstructed weight used in compute
FSDP->>FT: optimizer step triggers _set_data
FT->>FT: create throwaway quantizer copy with group
FT->>FT: quantize via throwaway - allreduce amax
FT->>FT: result quantizer stripped - group never leaks
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant M as Module
participant Q as Quantizer
participant F as forward_impl
participant B as backward_impl
participant FT as Float8Tensor weight
participant FSDP as FSDP2
Note over M,B: TP sequence-parallel path
M->>F: forward
F->>Q: set_quantizer_amax_reduction_group(tp_group)
F->>Q: quantize input - allreduce amax then cast
Q->>Q: strip with_amax_reduction from result._quantizer
F-->>B: save context with precomputed split-accum flags
B->>Q: set_quantizer_amax_reduction_group(tp_group)
B->>Q: re-quantize input for wgrad - allreduce amax
Q->>Q: strip with_amax_reduction
Note over FT,FSDP: FSDP2 weight path
FSDP->>FT: fsdp_pre_all_gather - set amax_reduction_group on tensor
FSDP->>FT: all-gather reconstructed weight used in compute
FSDP->>FT: optimizer step triggers _set_data
FT->>FT: create throwaway quantizer copy with group
FT->>FT: quantize via throwaway - allreduce amax
FT->>FT: result quantizer stripped - group never leaks
Reviews (9): Last reviewed commit: "Unwrap DebugQuantizer when setting amax ..." | Re-trigger Greptile |
|
Blocked by FSDP bug, refactor in progress. I plan to store .amax_reduction_group in QuantizedTensor. |
There was a problem hiding this comment.
This would be a design mistake. The amax reduction does not have a consistent meaning across recipes (including recipes where it doesn't make sense), and this change requires spilling out amax reduction logic into quantizer callsites (even where it doesn't make sense).
Can you go into more detail exactly why torch.compile doesn't work when quantizers have process groups? If we just want the quantizer to hold simple Python objects, maybe we can make the quantizer hold an int for the communicator ID. I envision something like:
class Float8CurrentScalingQuantizer(Quantizer):
_communicator_cache = {}
@property
def amax_reduction_group(self):
if self._amax_reduction_group_id is None:
return None
return Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id]
@property.setter
def amax_reduction_group(self, comm):
if comm is None:
self._amax_reduction_group_id = None
self._amax_reduction_group_id = id(comm)
Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id] = commI'm not sure how this would interact with checkpointing though.
| dst: QuantizedTensor, | ||
| *, | ||
| noop_flag: Optional[torch.Tensor] = None, | ||
| amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument |
There was a problem hiding this comment.
I strongly oppose this API change. amax reduction is very recipe-specific. It has different meanings for different recipes (FP8 DS might reduce over the TP+DP group, FP8 CS might only reduce over the TP group) and it has no meaning for other recipes (MXFP8 and FP8 block scaling). Moving it into the generic API will leak recipe-specific information, defeating the point of a generic API.
|
/te-ci pytorch L1 |
…stants; fix SP memory leak; test suite hook-up Wrap CommOverlapCore pybind11 methods that return compile-time constants so torch.compile(fullgraph=True) can trace through them without graph breaks: - `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py; `_ub_is_fp8()` in gemm.py - `with_cublasmp()` → `ub_is_cublasmp()` in base.py All callers in linear.py, layernorm_linear.py, layernorm_mlp.py, base.py, gemm.py, userbuffers_backward_linear.py and userbuffers_forward_linear.py updated. Fix quantized grad_output not being freed early for column-parallel SP backward. Row-parallel SP already called clear_tensor_data(grad_output) to release the gathered tensor; column-parallel SP quantizes grad_output to Float8TensorStorage but never freed it before returning. Under torch.compile reduce-overhead this leaves 3 live pool tensors at recording end and triggers "Detected 3 tensor(s) in the cudagraph pool not tracked as outputs". Extend the existing clear_tensor_data guard to cover both parallel modes. Fix custom-recipe quantizer state being re-initialised on every forward call even when the recipe object has not changed. The existing early-exit for CustomRecipeState was missing an identity check on the recipe object, so any repeated call with the same recipe would bypass the early-return and rebuild quantizers unnecessarily. Add `if recipe_state.recipe is recipe: return` to restore the intended caching behaviour. Add test_torch_compile.py to L0_pytorch_unittest so the autocast and existing compile tests run in CI. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…-accumulator booleans LinearBwdArgs stored the entire FP8 recipe object so the backward could extract fp8_gemm_dgrad.use_split_accumulator and fp8_gemm_wgrad.use_split_accumulator at GEMM time. Recipe objects hold process-group references and are not serialisable as compile-time constants, making them incompatible with torch.compile custom-op paths. Replace fp8_recipe with two plain bool fields: - dgrad_use_split_accumulator (default _2X_ACC_DGRAD) - wgrad_use_split_accumulator (default _2X_ACC_WGRAD) These are resolved once in _linear_setup_ctx and passed into the args struct, so the backward consumes scalars instead of a live recipe object. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
for more information, see https://pre-commit.ci
…t_result get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a reset, destroy_ub + re-init with different FP8 settings would read stale values until recompile. Only affects in-memory caches, not disk. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
e9097d6 to
948cd6d
Compare
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit the no-roles warning, which graph-breaks under fullgraph=True. qfactory dispatches on role.tensor_type instead of a pre-baked string key. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The amax reduction process group is no longer stored persistently on a module quantizer or on a tensor's quantizer. No C++ changes. - TP sequence parallel: the group is set on the input/grad-output quantizer at point of use in the fwd/bwd impls (linear, layernorm_linear, layernorm_mlp, ops basic_linear), replacing the setup-time _customize_quantizers wiring. - FSDP2: the group is stored on Float8Tensor/NVFP4Tensor (set in fsdp_pre_all_gather) and applied to a throwaway quantizer copy during the in-place re-quant (update_quantized / _set_data). - quantize() strips the group off the output tensor's quantizer so it never persists on any tensor's quantizer (breaks flatten/pickle otherwise). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
b8c1bec to
6c9b986
Compare
|
/te-ci pytorch L1 |
set_quantizer_amax_reduction_group was a no-op on a DebugQuantizer (it lacks with_amax_reduction), so with nvinspect enabled the parent quantizer never got the SP amax reduction group, breaking fp8 current scaling column-parallel sequence-parallel numerics (debug test_numerics). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
|
@timmoon10 I have changed the concept of this PR. Note that it also consist the changed from #3130 |
|
/te-ci pytorch L1 |
Description
There are 2 problems with amax_reduction_group on Quantizers:
Also, current design is prone to bugs. Amax reduction group is set in module forward(). So if we have forward for tp_group1, then forward for tp_group2, this second forwad overrides the amax_reduction_group of quantizers which are used in both backwards. So I think it is better design to set amax reduction group in forward and backward directly.
We may need to slightly refactor the quantizer to mitigate this kinds of bugs, but for now the change in this PR will be sufficient for torch.compile support.
Fixes # (N/A)
Type of change
Changes
quantizer at point of use in the fwd/bwd impls (
linear,layernorm_linear,layernorm_mlp,ops/basic_linear) via a newset_quantizer_amax_reduction_grouphelper, instead of once at module setup.Removed the group wiring from
_customize_quantizers_float8_current_scalingand dropped
_customize_quantizers_nvfp4entirely.Float8Tensor/NVFP4Tensor(
amax_reduction_groupattribute, set infsdp_pre_all_gather) and apply itto a throwaway quantizer copy during the in-place re-quant
(
update_quantized/_set_data) — the weight's own quantizer is never mutated.Quantizer.quantize()strips the group off the output tensor's quantizer, soit never persists on any tensor's
_quantizer.Checklist: