Abstract CUDA hardcodes into configurable te_device_type / te_platform#3113
Abstract CUDA hardcodes into configurable te_device_type / te_platform#3113lxd-cumt wants to merge 5 commits into
Conversation
Greptile SummaryThis PR introduces a soft device-type abstraction layer for TransformerEngine's PyTorch backend, replacing ~200 hardcoded
Confidence Score: 4/5Safe to merge on standard CUDA deployments; the MUSA plugin path has several minor inconsistencies that could silently disable optimisations or surface incorrect error messages on alternative backends. All changes on the standard CUDA path are mechanical string substitutions that evaluate to the same value, so no existing behaviour changes. The MUSA plugin path has a dead-code ternary in Float8BlockQuantizer.make_empty, a double-space error message, a potential LRU-cache device confusion in get_rht_matrix, and two hard-coded torch.cuda.* guards that survive the device-type check in GroupedTensorStorage — none of these affect CUDA users, but they leave the MUSA path incompletely abstracted. transformer_engine/pytorch/tensor/nvfp4_tensor.py (LRU cache + device move interaction), transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py (mixed device guards), transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py (double-space typo). Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant Init as transformer_engine/__init__.py
participant Plugin as transformer_engine_plugin_fl.patches
participant PyTorch as transformer_engine/pytorch/*
participant User as User Code
Init->>Init: "TE_DEVICE_TYPE = "cuda""
Init->>Init: "TE_PLATFORM = torch.cuda"
alt "NVTE_ENABLE_PLUGIN=1"
Init->>Plugin: apply_patches()
Plugin->>Init: "TE_DEVICE_TYPE = "musa""
Plugin->>Init: "TE_PLATFORM = torch.musa"
Plugin->>Plugin: "monkey-patch torch.cuda.* → torch.musa.*"
end
Init->>PyTorch: from . import pytorch
Note over PyTorch: Default args like device=te_device_type() captured here (after patch)
PyTorch->>PyTorch: "gpu_autocast_ctx = functools.partial(..., device_type=te_device_type())"
User->>PyTorch: "TransformerLayer(device=te_device_type())"
PyTorch->>PyTorch: "torch.cuda.* calls resolved via monkey-patch or native CUDA"
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant Init as transformer_engine/__init__.py
participant Plugin as transformer_engine_plugin_fl.patches
participant PyTorch as transformer_engine/pytorch/*
participant User as User Code
Init->>Init: "TE_DEVICE_TYPE = "cuda""
Init->>Init: "TE_PLATFORM = torch.cuda"
alt "NVTE_ENABLE_PLUGIN=1"
Init->>Plugin: apply_patches()
Plugin->>Init: "TE_DEVICE_TYPE = "musa""
Plugin->>Init: "TE_PLATFORM = torch.musa"
Plugin->>Plugin: "monkey-patch torch.cuda.* → torch.musa.*"
end
Init->>PyTorch: from . import pytorch
Note over PyTorch: Default args like device=te_device_type() captured here (after patch)
PyTorch->>PyTorch: "gpu_autocast_ctx = functools.partial(..., device_type=te_device_type())"
User->>PyTorch: "TransformerLayer(device=te_device_type())"
PyTorch->>PyTorch: "torch.cuda.* calls resolved via monkey-patch or native CUDA"
Reviews (4): Last reviewed commit: "fix: patch cuda hardcodes for plugin com..." | Re-trigger Greptile |
| dpa_utils._original_get_attention_backend = dpa_utils.get_attention_backend | ||
| # Replace dpa_utils.get_attention_backend with tex.get_attention_backend | ||
| # This allows each backend (FlagOS, CUDA, Reference) to control its own backend selection | ||
| dpa_utils.get_attention_backend = tex.get_attention_backend |
There was a problem hiding this comment.
AttributeError at import time breaks the PyTorch module
tex is transformer_engine_torch (the C++ extension), which does not expose a get_attention_backend attribute. Accessing tex.get_attention_backend directly on line 75 (without a getattr guard) raises AttributeError the moment transformer_engine.pytorch is imported, making the entire PyTorch backend unusable on any standard CUDA installation. The previous line (69) correctly uses getattr(tex, "flash_attention", _FlashAttentionNative) with a fallback — the same pattern must be applied here, or the unconditional attribute access must be removed.
| try: | ||
| from transformer_engine_torch import bulk_overlap_ag_with_external_gemm | ||
| except ImportError: | ||
| bulk_overlap_ag_with_external_gemm = None |
There was a problem hiding this comment.
NoneType call crash in backward pass when bulk_overlap_ag_with_external_gemm is unavailable
The import is now guarded (= None on failure), but line 435 calls bulk_overlap_ag_with_external_gemm(ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream) unconditionally. Whenever this code path is hit in a tensor-parallel row-overlap backward pass on a system where this symbol is absent, a TypeError: 'NoneType' object is not callable is raised at runtime rather than at import time. A guard like if bulk_overlap_ag_with_external_gemm is not None: (or raising a descriptive error earlier) is needed at the call site.
| try: | ||
| from .plugin.core.backends.vendor.musa.patches import apply_patch as _musa_apply_patch | ||
|
|
||
| _musa_apply_patch() | ||
| except Exception as e: | ||
| pass |
There was a problem hiding this comment.
Silent
except Exception: pass hides all MUSA patch failures
The bare except Exception as e: pass swallows every failure during the MUSA patch import — including AttributeError raised when torch.musa.* attributes referenced in _PATCH_CALLS don't exist on standard CUDA systems. The variable e is never logged or inspected. On CUDA systems this is the common path (no torch.musa), so every import of transformer_engine silently triggers and discards an exception. At minimum, emit a logging.debug or use a narrower exception type (e.g., ImportError) and let other errors propagate.
| # Patches: (parent_object, attribute_name, replacement_callable) | ||
| _PATCH_CALLS: list[tuple[object, str, Callable[..., object]]] = [ | ||
| # We do not recommend replace is_available, due to its device-related behavior. | ||
| # (torch.cuda, "is_available", torch.musa.is_available), | ||
| (torch.cuda, "get_device_properties", torch.musa.get_device_properties), | ||
| (torch.cuda, "device", torch.musa.device), | ||
| (torch.cuda, "current_device", torch.musa.current_device), | ||
| (torch.cuda, "synchronize", torch.musa.synchronize), | ||
| (torch.cuda, "is_current_stream_capturing", torch.musa.is_current_stream_capturing), | ||
| # TODO: Add NVTX patches for MUSA. | ||
| # NVTX is CUDA-specific; make it a no-op on MUSA. | ||
| (torch.cuda.nvtx, "range_push", _noop), | ||
| (torch.cuda.nvtx, "range_pop", _noop), |
There was a problem hiding this comment.
_PATCH_CALLS accesses torch.musa.* at module load time
_PATCH_CALLS is a module-level list that dereferences torch.musa.get_device_properties, torch.musa.device, etc. when patches.py is imported. On any system without torch_musa, this raises AttributeError the moment the import is attempted. The caller in __init__.py wraps this in a blanket except Exception: pass, so it fails silently, but it means the module is broken by construction on non-MUSA hosts. Deferring the attribute lookups to inside apply_patch() (where hasattr(torch, "musa") is already checked) would make the module safe to import on all platforms.
| if device.type != te_device_type(): | ||
| raise ValueError(f"Only CUDA devices are supported (got {device})") |
There was a problem hiding this comment.
The condition now guards against non-
te_device_type() devices, but the error message still hard-codes "CUDA". On a MUSA system a user would see "Only CUDA devices are supported" when their device is actually valid, which is misleading.
| if device.type != te_device_type(): | |
| raise ValueError(f"Only CUDA devices are supported (got {device})") | |
| if device.type != te_device_type(): | |
| raise ValueError(f"Only {te_device_type().upper()} devices are supported (got {device})") |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| if device.type != te_device_type(): | ||
| raise ValueError(f"Only CUDA devices are supported (got {device})") |
There was a problem hiding this comment.
Same stale error message: the guard now checks
te_device_type() but the message still says "CUDA".
| if device.type != te_device_type(): | |
| raise ValueError(f"Only CUDA devices are supported (got {device})") | |
| if device.type != te_device_type(): | |
| raise ValueError(f"Only {te_device_type().upper()} devices are supported (got {device})") |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| # Mark TE global device type for Python-side callers. | ||
| # IMPORTANT: do not import `transformer_engine` here, because TE's `__init__.py` | ||
| # imports this module to run patches and that would cause a circular import. | ||
| try: | ||
| import transformer_engine | ||
|
|
||
| transformer_engine.TE_DEVICE_TYPE = "musa" | ||
| transformer_engine.TE_PLATFORM = torch.musa |
There was a problem hiding this comment.
The comment warns against importing
transformer_engine here due to a circular-import risk, but the very next lines do exactly that. During transformer_engine/__init__.py execution Python's import cache returns the partially-initialized module, so TE_DEVICE_TYPE (set before the patch call) is reachable and the assignment works — but the comment creates a false sense of safety and the approach is still fragile if the import order ever changes.
| # Mark TE global device type for Python-side callers. | |
| # IMPORTANT: do not import `transformer_engine` here, because TE's `__init__.py` | |
| # imports this module to run patches and that would cause a circular import. | |
| try: | |
| import transformer_engine | |
| transformer_engine.TE_DEVICE_TYPE = "musa" | |
| transformer_engine.TE_PLATFORM = torch.musa | |
| # Mark TE global device type for Python-side callers. | |
| # NOTE: importing `transformer_engine` here re-enters a partially-initialised module | |
| # (its __init__.py is still running), but Python's import cache makes this safe as long | |
| # as TE_DEVICE_TYPE is assigned before apply_patch() is called. | |
| try: | |
| import transformer_engine | |
| transformer_engine.TE_DEVICE_TYPE = "musa" | |
| transformer_engine.TE_PLATFORM = torch.musa |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| if device is None: | ||
| # Use default CUDA device | ||
| device = torch.get_default_device() | ||
| if device.type != "cuda": | ||
| device = torch.device("cuda", torch.cuda.current_device()) | ||
| if device.type != te_device_type(): | ||
| device = torch.device(te_device_type(), torch.cuda.current_device()) | ||
| elif not isinstance(device, torch.device): | ||
| device = torch.device(device) | ||
| if device.type == "cuda" and device.index is None: | ||
| device = torch.device("cuda", torch.cuda.current_device()) | ||
| if device.type == te_device_type() and device.index is None: | ||
| device = torch.device(te_device_type(), torch.cuda.current_device()) |
There was a problem hiding this comment.
After monkey-patching,
torch.cuda.current_device() is still called directly here. On MUSA hosts this works only because the patch replaces torch.cuda.current_device with torch.musa.current_device. Using te_platform().current_device() makes the intent explicit and works even if the monkey-patch is not applied.
| if device is None: | |
| # Use default CUDA device | |
| device = torch.get_default_device() | |
| if device.type != "cuda": | |
| device = torch.device("cuda", torch.cuda.current_device()) | |
| if device.type != te_device_type(): | |
| device = torch.device(te_device_type(), torch.cuda.current_device()) | |
| elif not isinstance(device, torch.device): | |
| device = torch.device(device) | |
| if device.type == "cuda" and device.index is None: | |
| device = torch.device("cuda", torch.cuda.current_device()) | |
| if device.type == te_device_type() and device.index is None: | |
| device = torch.device(te_device_type(), torch.cuda.current_device()) | |
| if device is None: | |
| # Use default CUDA device | |
| device = torch.get_default_device() | |
| if device.type != te_device_type(): | |
| device = torch.device(te_device_type(), te_platform().current_device()) | |
| elif not isinstance(device, torch.device): | |
| device = torch.device(device) | |
| if device.type == te_device_type() and device.index is None: | |
| device = torch.device(te_device_type(), te_platform().current_device()) |
| super().__init__() | ||
| if not torch.cuda.is_available(): | ||
| raise RuntimeError("TransformerEngine needs CUDA.") | ||
| assert te_platform().is_available(), f"TransformerEngine needs {te_device_type()}." |
There was a problem hiding this comment.
assert for runtime environment checks is stripped when Python runs with the -O flag, silently bypassing the guard. The original code used an explicit raise RuntimeError, which fires unconditionally. Replacing it with assert means a user running python -O can instantiate a TransformerEngineBaseModule on a system without the required hardware, only to get a confusing crash later deep inside a CUDA/MUSA kernel.
| assert te_platform().is_available(), f"TransformerEngine needs {te_device_type()}." | |
| if not te_platform().is_available(): | |
| raise RuntimeError(f"TransformerEngine needs {te_device_type()}.") |
|
Thanks for the review! I've addressed the above Greptile comments. |
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
…th explicit raise Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
Signed-off-by: Xianduo Li <lixianduo@mail.nankai.edu.cn>
|
Should this PR be merged into the |
FlagOS Proposal: Plugin Architecture & Device-Agnostic Abstraction for TransformerEngine
Device-Type Abstraction: Replacing Hardcoded
"cuda"ReferencesThe current TE PyTorch layer contains ~100 hardcoded
"cuda"string literals and ~165torch.cuda.*API calls. These span device placement (device="cuda"), autocast context (device_type="cuda"), device-type guards (device.type == "cuda"), and RNG state management (torch.cuda.CUDAGraph,torch.cuda._lazy_call). This makes TE non-functional on alternative accelerator platforms without invasive patching.Proposed Design
Soft abstraction – A global
te_device_type()/te_platform()accessor replaces ~200 literal"cuda"strings across the Python codebase.Platform monkey-patch – A vendor-provided
apply_patch()hook runs at import time to directly remaptorch.cuda.*APIs (e.g.torch.cuda.device,torch.cuda.current_device,torch.cuda.current_stream) to the vendor equivalents (e.g.torch.other_vendor.*).