
Abstract CUDA hardcodes into configurable te_device_type / te_platform #3113

Open
lxd-cumt wants to merge 5 commits into NVIDIA:release_v2.14 from lxd-cumt:cuda_patch


@lxd-cumt lxd-cumt commented Jun 10, 2026


FlagOS Proposal: Plugin Architecture & Device-Agnostic Abstraction for TransformerEngine

Device-Type Abstraction: Replacing Hardcoded "cuda" References

The current TE PyTorch layer contains ~100 hardcoded "cuda" string literals and ~165 torch.cuda.* API calls. These span device placement (device="cuda"), autocast context (device_type="cuda"), device-type guards (device.type == "cuda"), and RNG state management (torch.cuda.CUDAGraph, torch.cuda._lazy_call). This makes TE non-functional on alternative accelerator platforms without invasive patching.

Proposed Design

  1. Soft abstraction – A global te_device_type() / te_platform() accessor replaces ~200 literal "cuda" strings across the Python codebase.

  2. Platform monkey-patch – A vendor-provided apply_patch() hook runs at import time to directly remap torch.cuda.* APIs (e.g. torch.cuda.device, torch.cuda.current_device, torch.cuda.current_stream) to the vendor equivalents (e.g. torch.other_vendor.*).
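The soft abstraction in item 1 can be sketched as follows. This is a minimal illustration, assuming the accessors simply read shared state that a plugin may overwrite. The `fake_cuda` / `fake_musa` namespaces stand in for `torch.cuda` and a vendor namespace so the sketch runs without PyTorch, and `set_platform` and `describe_current_device` are hypothetical helpers, not names from the PR.

```python
import types

# Stand-ins for torch.cuda and a vendor namespace such as torch.musa
# (assumption: only these two calls are needed for the illustration).
fake_cuda = types.SimpleNamespace(is_available=lambda: True, current_device=lambda: 0)
fake_musa = types.SimpleNamespace(is_available=lambda: True, current_device=lambda: 1)

_STATE = {"device_type": "cuda", "platform": fake_cuda}


def te_device_type() -> str:
    """Return the active device-type string, looked up at call time."""
    return _STATE["device_type"]


def te_platform():
    """Return the namespace that provides the torch.cuda-style API."""
    return _STATE["platform"]


def set_platform(device_type: str, platform) -> None:
    """Hypothetical hook a vendor plugin could call at import time."""
    _STATE["device_type"] = device_type
    _STATE["platform"] = platform


def describe_current_device() -> str:
    # Call sites use the accessors instead of the literal "cuda".
    return f"{te_device_type()}:{te_platform().current_device()}"
```

Because call sites go through the accessors, swapping the backend becomes a single state change rather than an edit at every literal.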

@github-actions github-actions Bot added the community-contribution label (PRs from external contributors outside the core maintainers, representing community-driven work) Jun 10, 2026

greptile-apps Bot commented Jun 10, 2026


Greptile Summary

This PR introduces a soft device-type abstraction layer for TransformerEngine's PyTorch backend, replacing ~200 hardcoded "cuda" string literals across 45 files with calls to te_device_type() / te_platform() global accessors, and adds a plugin hook (NVTE_ENABLE_PLUGIN=1) that lets a vendor-provided patch remap both the device-type string and torch.cuda.* API calls at import time. Remaining torch.cuda.* API calls (RNG state, streams, generators) are intentionally left to be handled by the monkey-patch path described in the PR.

  • New global accessors – te_device_type() and te_platform() are added to transformer_engine/__init__.py along with a guarded plugin loader that now emits a RuntimeWarning on failure instead of silently swallowing the error.
  • Systematic string replacement – Device strings and device_type= arguments for autocast, tensor constructors, device guards, and quantizer defaults across the entire pytorch layer are updated to use te_device_type().
  • Bundled attention fixes – A flash-attn 2.3.x–2.6.x rng_state dropout fix in context parallel and a new determinism guard that disables FusedAttention on Blackwell (SM ≥ 10.0) are included alongside the abstraction changes.

Confidence Score: 4/5

Safe to merge on standard CUDA deployments; the MUSA plugin path has several minor inconsistencies that could silently disable optimisations or surface incorrect error messages on alternative backends.

All changes on the standard CUDA path are mechanical string substitutions that evaluate to the same value, so no existing behaviour changes. The MUSA plugin path has a dead-code ternary in Float8BlockQuantizer.make_empty, a double-space error message, a potential LRU-cache device confusion in get_rht_matrix, and two hard-coded torch.cuda.* guards that survive the device-type check in GroupedTensorStorage — none of these affect CUDA users, but they leave the MUSA path incompletely abstracted.

transformer_engine/pytorch/tensor/nvfp4_tensor.py (LRU cache + device move interaction), transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py (mixed device guards), transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py (double-space typo).

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/__init__.py | Introduces te_device_type() / te_platform() accessors and the plugin loader; plugin failures now emit a RuntimeWarning rather than silently swallowing errors. |
| transformer_engine/pytorch/utils.py | Module-level gpu_autocast_ctx captures te_device_type() once at import time; works correctly because the plugin runs before pytorch is imported, but is a subtle coupling to import order. |
| transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py | Device-type guard updated; error message has a double-space typo ({te_device_type().upper()} devices). |
| transformer_engine/pytorch/distributed.py | _get_cuda_rng_state/_set_cuda_rng_state default parameter updated to te_device_type(); bodies still use torch.cuda.current_device() and torch.cuda.default_generators (intentionally covered by monkey-patch). |
| transformer_engine/pytorch/jit.py | Tensor creation in warmup functions moved to te_device_type(); torch.cuda.get_rng_state()/set_rng_state()/empty_cache() calls intentionally left for monkey-patch coverage. |
| transformer_engine/pytorch/tensor/nvfp4_tensor.py | Adds .to(te_device_type()) to get_rht_matrix return after tensor already lives on the given device; can silently move cached results away from the requested device index. |
| transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | Float8BlockQuantizer.make_empty has dead code: the new if device is None guard sets device before the ternary still checks if device is None, making the true-branch unreachable. |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Attention backend additions: FusedAttention disabled for thd format with max_logit, and disabled for deterministic training on Blackwell (SM>=10.0); device string literals replaced. |
| transformer_engine/pytorch/module/base.py | Device guards and tensor allocations use te_device_type(); the RuntimeError pattern for device availability check is preserved correctly. |
| transformer_engine/pytorch/permutation.py | Systematic replacement of tensor.is_cuda checks with tensor.device.type == te_device_type(); changes are mechanical and consistent. |

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Init as transformer_engine/__init__.py
    participant Plugin as transformer_engine_plugin_fl.patches
    participant PyTorch as transformer_engine/pytorch/*
    participant User as User Code

    Init->>Init: "TE_DEVICE_TYPE = "cuda""
    Init->>Init: "TE_PLATFORM = torch.cuda"
    alt "NVTE_ENABLE_PLUGIN=1"
        Init->>Plugin: apply_patches()
        Plugin->>Init: "TE_DEVICE_TYPE = "musa""
        Plugin->>Init: "TE_PLATFORM = torch.musa"
        Plugin->>Plugin: "monkey-patch torch.cuda.* → torch.musa.*"
    end
    Init->>PyTorch: from . import pytorch
    Note over PyTorch: Default args like device=te_device_type() captured here (after patch)
    PyTorch->>PyTorch: "gpu_autocast_ctx = functools.partial(..., device_type=te_device_type())"
    User->>PyTorch: "TransformerLayer(device=te_device_type())"
    PyTorch->>PyTorch: "torch.cuda.* calls resolved via monkey-patch or native CUDA"

Reviews (4): Last reviewed commit: "fix: patch cuda hardcodes for plugin com..."
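The import-order coupling noted for gpu_autocast_ctx in the summary can be illustrated with a small sketch. It assumes nothing about the real module beyond the use of functools.partial: the `autocast` function is a stand-in for torch.autocast, and `eager_autocast` / `lazy_autocast` are hypothetical names. A partial built at import time freezes the device type it saw, while a wrapper that calls the accessor on each use picks up a later patch.

```python
import functools

_STATE = {"device_type": "cuda"}


def te_device_type() -> str:
    return _STATE["device_type"]


def autocast(device_type: str, enabled: bool = True):
    """Stand-in for torch.autocast; returns its arguments for inspection."""
    return (device_type, enabled)


# Eager: the device type is evaluated once, when this line runs.
eager_autocast = functools.partial(autocast, device_type=te_device_type())


def lazy_autocast(**kwargs):
    """Lazy: the device type is looked up on every call."""
    return autocast(device_type=te_device_type(), **kwargs)


# Simulate a plugin that changes the device type after the partial was built.
_STATE["device_type"] = "musa"

eager_result = eager_autocast()  # still uses the value captured earlier
lazy_result = lazy_autocast()    # sees the patched value
```

In the PR the eager form works only because the plugin runs before transformer_engine.pytorch is imported; the lazy form would not depend on that ordering.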

dpa_utils._original_get_attention_backend = dpa_utils.get_attention_backend
# Replace dpa_utils.get_attention_backend with tex.get_attention_backend
# This allows each backend (FlagOS, CUDA, Reference) to control its own backend selection
dpa_utils.get_attention_backend = tex.get_attention_backend


P1 AttributeError at import time breaks the PyTorch module

tex is transformer_engine_torch (the C++ extension), which does not expose a get_attention_backend attribute. Accessing tex.get_attention_backend directly on line 75 (without a getattr guard) raises AttributeError the moment transformer_engine.pytorch is imported, making the entire PyTorch backend unusable on any standard CUDA installation. The previous line (69) correctly uses getattr(tex, "flash_attention", _FlashAttentionNative) with a fallback — the same pattern must be applied here, or the unconditional attribute access must be removed.
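A guarded version of the override described above might look like the following sketch. The `tex` and `dpa_utils` objects here are plain stand-ins (assumptions made so the example runs without TransformerEngine installed); only the getattr-with-fallback pattern is the point.

```python
import types

# Stand-in for the compiled extension module; assumed to lack the attribute.
tex = types.SimpleNamespace()


def _native_get_attention_backend(*args, **kwargs):
    return "native"


# Stand-in for the dot_product_attention utils module.
dpa_utils = types.SimpleNamespace(get_attention_backend=_native_get_attention_backend)

dpa_utils._original_get_attention_backend = dpa_utils.get_attention_backend
# Only override when the extension actually provides a replacement;
# otherwise keep the original implementation.
dpa_utils.get_attention_backend = getattr(
    tex, "get_attention_backend", dpa_utils._original_get_attention_backend
)
```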

Comment on lines +15 to +18
try:
    from transformer_engine_torch import bulk_overlap_ag_with_external_gemm
except ImportError:
    bulk_overlap_ag_with_external_gemm = None


P1 NoneType call crash in backward pass when bulk_overlap_ag_with_external_gemm is unavailable

The import is now guarded (= None on failure), but line 435 calls bulk_overlap_ag_with_external_gemm(ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream) unconditionally. Whenever this code path is hit in a tensor-parallel row-overlap backward pass on a system where this symbol is absent, a TypeError: 'NoneType' object is not callable is raised at runtime rather than at import time. A guard like if bulk_overlap_ag_with_external_gemm is not None: (or raising a descriptive error earlier) is needed at the call site.
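One way to express the call-site guard is sketched below. The helper names `optional_symbol` and `call_required` are hypothetical and not part of the PR; the sketch only shows an optional import paired with a descriptive error at the point of use.

```python
import importlib


def optional_symbol(module_name: str, symbol: str):
    """Return module.symbol, or None when the module or symbol is missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, symbol, None)


def call_required(fn, name: str, *args, **kwargs):
    """Call fn, raising a descriptive error if the optional symbol is absent."""
    if fn is None:
        raise RuntimeError(
            f"{name} is not available in this build; "
            "the code path that needs it cannot run."
        )
    return fn(*args, **kwargs)
```

With this shape, a missing symbol surfaces as a RuntimeError that names the symbol instead of a TypeError about NoneType.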

Comment thread transformer_engine/__init__.py Outdated
Comment on lines +20 to +25
try:
    from .plugin.core.backends.vendor.musa.patches import apply_patch as _musa_apply_patch

    _musa_apply_patch()
except Exception as e:
    pass


P1 Silent except Exception: pass hides all MUSA patch failures

The bare except Exception as e: pass swallows every failure during the MUSA patch import — including AttributeError raised when torch.musa.* attributes referenced in _PATCH_CALLS don't exist on standard CUDA systems. The variable e is never logged or inspected. On CUDA systems this is the common path (no torch.musa), so every import of transformer_engine silently triggers and discards an exception. At minimum, emit a logging.debug or use a narrower exception type (e.g., ImportError) and let other errors propagate.
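A narrower loader along the lines suggested above could be sketched like this. It is an illustration under assumptions: `load_plugin` is a hypothetical function, the environment is passed in as a dict for testability, and only the NVTE_ENABLE_PLUGIN variable and the apply_patch() hook name come from this PR.

```python
import importlib
import warnings


def load_plugin(module_name: str, env: dict) -> bool:
    """Import module_name and run its apply_patch() when the plugin is enabled."""
    if env.get("NVTE_ENABLE_PLUGIN") != "1":
        return False
    try:
        patches = importlib.import_module(module_name)
    except ImportError as exc:
        # Only a missing plugin is tolerated, and it is reported, not hidden.
        warnings.warn(f"Plugin {module_name!r} not loaded: {exc}", RuntimeWarning)
        return False
    # Errors raised by the patch itself propagate to the caller.
    patches.apply_patch()
    return True
```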

Comment on lines +14 to +26
# Patches: (parent_object, attribute_name, replacement_callable)
_PATCH_CALLS: list[tuple[object, str, Callable[..., object]]] = [
    # We do not recommend replace is_available, due to its device-related behavior.
    # (torch.cuda, "is_available", torch.musa.is_available),
    (torch.cuda, "get_device_properties", torch.musa.get_device_properties),
    (torch.cuda, "device", torch.musa.device),
    (torch.cuda, "current_device", torch.musa.current_device),
    (torch.cuda, "synchronize", torch.musa.synchronize),
    (torch.cuda, "is_current_stream_capturing", torch.musa.is_current_stream_capturing),
    # TODO: Add NVTX patches for MUSA.
    # NVTX is CUDA-specific; make it a no-op on MUSA.
    (torch.cuda.nvtx, "range_push", _noop),
    (torch.cuda.nvtx, "range_pop", _noop),


P1 _PATCH_CALLS accesses torch.musa.* at module load time

_PATCH_CALLS is a module-level list that dereferences torch.musa.get_device_properties, torch.musa.device, etc. when patches.py is imported. On any system without torch_musa, this raises AttributeError the moment the import is attempted. The caller in __init__.py wraps this in a blanket except Exception: pass, so it fails silently, but it means the module is broken by construction on non-MUSA hosts. Deferring the attribute lookups to inside apply_patch() (where hasattr(torch, "musa") is already checked) would make the module safe to import on all platforms.

Comment on lines 159 to 160
if device.type != te_device_type():
    raise ValueError(f"Only CUDA devices are supported (got {device})")


P2 The condition now guards against non-te_device_type() devices, but the error message still hard-codes "CUDA". On a MUSA system a user would see "Only CUDA devices are supported" when their device is actually valid, which is misleading.

Suggested change
- if device.type != te_device_type():
-     raise ValueError(f"Only CUDA devices are supported (got {device})")
+ if device.type != te_device_type():
+     raise ValueError(f"Only {te_device_type().upper()} devices are supported (got {device})")

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines 186 to 187
if device.type != te_device_type():
    raise ValueError(f"Only CUDA devices are supported (got {device})")


P2 Same stale error message: the guard now checks te_device_type() but the message still says "CUDA".

Suggested change
- if device.type != te_device_type():
-     raise ValueError(f"Only CUDA devices are supported (got {device})")
+ if device.type != te_device_type():
+     raise ValueError(f"Only {te_device_type().upper()} devices are supported (got {device})")

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +43 to +50
# Mark TE global device type for Python-side callers.
# IMPORTANT: do not import `transformer_engine` here, because TE's `__init__.py`
# imports this module to run patches and that would cause a circular import.
try:
    import transformer_engine

    transformer_engine.TE_DEVICE_TYPE = "musa"
    transformer_engine.TE_PLATFORM = torch.musa


P2 The comment warns against importing transformer_engine here due to a circular-import risk, but the very next lines do exactly that. During transformer_engine/__init__.py execution Python's import cache returns the partially-initialized module, so TE_DEVICE_TYPE (set before the patch call) is reachable and the assignment works — but the comment creates a false sense of safety and the approach is still fragile if the import order ever changes.

Suggested change
- # Mark TE global device type for Python-side callers.
- # IMPORTANT: do not import `transformer_engine` here, because TE's `__init__.py`
- # imports this module to run patches and that would cause a circular import.
- try:
-     import transformer_engine
-     transformer_engine.TE_DEVICE_TYPE = "musa"
-     transformer_engine.TE_PLATFORM = torch.musa
+ # Mark TE global device type for Python-side callers.
+ # NOTE: importing `transformer_engine` here re-enters a partially-initialised module
+ # (its __init__.py is still running), but Python's import cache makes this safe as long
+ # as TE_DEVICE_TYPE is assigned before apply_patch() is called.
+ try:
+     import transformer_engine
+     transformer_engine.TE_DEVICE_TYPE = "musa"
+     transformer_engine.TE_PLATFORM = torch.musa

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
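The partially-initialised-module behaviour described in this comment can be reproduced with a throwaway package. The sketch below is only a demonstration of Python's import cache: the package name `demo_te_pkg` and the attribute `SEEN_AFTER_PATCH` are made up for the example, and the files are written to a temporary directory.

```python
import importlib
import pathlib
import sys
import tempfile

INIT_SRC = (
    'TE_DEVICE_TYPE = "cuda"\n'
    "from . import patches\n"
    "patches.apply_patch()\n"
    "SEEN_AFTER_PATCH = TE_DEVICE_TYPE\n"
)
PATCHES_SRC = (
    "def apply_patch():\n"
    "    import demo_te_pkg  # returns the partially initialised package\n"
    '    demo_te_pkg.TE_DEVICE_TYPE = "musa"\n'
)

with tempfile.TemporaryDirectory() as tmp:
    pkg_dir = pathlib.Path(tmp) / "demo_te_pkg"
    pkg_dir.mkdir()
    (pkg_dir / "__init__.py").write_text(INIT_SRC)
    (pkg_dir / "patches.py").write_text(PATCHES_SRC)
    sys.path.insert(0, tmp)
    try:
        importlib.invalidate_caches()
        demo = importlib.import_module("demo_te_pkg")
    finally:
        sys.path.remove(tmp)

# The patch ran while __init__.py was still executing and its write is visible.
device_type_after_import = demo.TE_DEVICE_TYPE
```

The write succeeds only because TE_DEVICE_TYPE is assigned before apply_patch() runs; if the assignment came after the call, the package's own initialisation would overwrite the plugin's value.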

Comment thread transformer_engine/pytorch/utils.py Outdated
Comment on lines +552 to +560
  if device is None:
      # Use default CUDA device
      device = torch.get_default_device()
-     if device.type != "cuda":
-         device = torch.device("cuda", torch.cuda.current_device())
+     if device.type != te_device_type():
+         device = torch.device(te_device_type(), torch.cuda.current_device())
  elif not isinstance(device, torch.device):
      device = torch.device(device)
- if device.type == "cuda" and device.index is None:
-     device = torch.device("cuda", torch.cuda.current_device())
+ if device.type == te_device_type() and device.index is None:
+     device = torch.device(te_device_type(), torch.cuda.current_device())


P2 After monkey-patching, torch.cuda.current_device() is still called directly here. On MUSA hosts this works only because the patch replaces torch.cuda.current_device with torch.musa.current_device. Using te_platform().current_device() makes the intent explicit and works even if the monkey-patch is not applied.

Suggested change
- if device is None:
-     # Use default CUDA device
-     device = torch.get_default_device()
-     if device.type != "cuda":
-         device = torch.device("cuda", torch.cuda.current_device())
-     if device.type != te_device_type():
-         device = torch.device(te_device_type(), torch.cuda.current_device())
- elif not isinstance(device, torch.device):
-     device = torch.device(device)
- if device.type == "cuda" and device.index is None:
-     device = torch.device("cuda", torch.cuda.current_device())
- if device.type == te_device_type() and device.index is None:
-     device = torch.device(te_device_type(), torch.cuda.current_device())
+ if device is None:
+     # Use default CUDA device
+     device = torch.get_default_device()
+     if device.type != te_device_type():
+         device = torch.device(te_device_type(), te_platform().current_device())
+ elif not isinstance(device, torch.device):
+     device = torch.device(device)
+ if device.type == te_device_type() and device.index is None:
+     device = torch.device(te_device_type(), te_platform().current_device())

  super().__init__()
- if not torch.cuda.is_available():
-     raise RuntimeError("TransformerEngine needs CUDA.")
+ assert te_platform().is_available(), f"TransformerEngine needs {te_device_type()}."


P1 assert for runtime environment checks is stripped when Python runs with the -O flag, silently bypassing the guard. The original code used an explicit raise RuntimeError, which fires unconditionally. Replacing it with assert means a user running python -O can instantiate a TransformerEngineBaseModule on a system without the required hardware, only to get a confusing crash later deep inside a CUDA/MUSA kernel.

Suggested change
- assert te_platform().is_available(), f"TransformerEngine needs {te_device_type()}."
+ if not te_platform().is_available():
+     raise RuntimeError(f"TransformerEngine needs {te_device_type()}.")
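The difference between the two guards under `-O` can be checked directly. The sketch below runs three short child interpreters; the guard strings are simplified stand-ins for the availability check, not the module's real code.

```python
import subprocess
import sys

GUARD = "assert False, 'TransformerEngine needs cuda.'"

# Default mode: the assert fires and the interpreter exits with an error.
normal = subprocess.run([sys.executable, "-c", GUARD], capture_output=True)

# Optimized mode (-O): assert statements are compiled out, so the guard is skipped.
optimized = subprocess.run([sys.executable, "-O", "-c", GUARD], capture_output=True)

RAISE_GUARD = "raise RuntimeError('TransformerEngine needs cuda.')"

# An explicit raise is unaffected by -O.
explicit = subprocess.run(
    [sys.executable, "-O", "-c", RAISE_GUARD], capture_output=True
)
```

Comparing the three return codes shows which guards survive optimized mode.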

@lxd-cumt

lxd-cumt commented Jun 17, 2026


Thanks for the review! I've addressed the above Greptile comments.

lxd-cumt and others added 5 commits June 17, 2026 15:53
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
…th explicit raise

Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
@lxd-cumt


Should this PR be merged into the main branch, so it can follow future NVIDIA TE releases?
