Ascend NPU Support: Fused Operators and Flash Linear Attention

Jun 5, 2026 · 3 min read

Twinkle provides first-class support for Huawei Ascend NPU through a comprehensive monkey-patching system that replaces standard CUDA operators with NPU-optimized fused kernels. This post covers the kernel architecture and the optimizations enabled.

Kernel Architecture

Twinkle’s kernel module (twinkle.kernel) provides a unified entry point kernelize_model() that automatically detects the device and applies appropriate optimizations:

from twinkle.kernel import kernelize_model
model = kernelize_model(model, device='npu')  # or auto-detected
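
As a rough illustration of what device auto-detection could look like, the sketch below picks a device string from what is importable in the environment. The helper name detect_device and the rule "an importable torch_npu means NPU" are assumptions made for this example, not Twinkle's actual logic.

```python
import importlib.util


def detect_device(preferred=None):
    """Return 'npu', 'cuda', or 'cpu'. Hypothetical helper, not Twinkle's API."""
    if preferred is not None:
        # An explicit choice always wins over detection.
        return preferred
    # Assumption: an importable torch_npu package signals an Ascend environment.
    if importlib.util.find_spec("torch_npu") is not None:
        return "npu"
    if importlib.util.find_spec("torch") is not None:
        import torch

        if torch.cuda.is_available():
            return "cuda"
    return "cpu"
```

Passing device='npu' explicitly, as in the call above, skips detection entirely.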

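On NPU devices, the following implementations replace their standard counterparts. Most are applied by default; the MoE grouped matmul patch is opt-in, and each one can be switched off through the environment variables listed under Environment Variable Control: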

| Operator | NPU Implementation | Benefit |
| --- | --- | --- |
| RMSNorm | torch_npu.npu_rms_norm | Fused normalization, ~2x faster |
| RoPE | torch_npu.npu_rotary_mul | Fused rotary embedding with partial RoPE support |
| SwiGLU | torch_npu.npu_swiglu | Fused gate+up projection activation |
| SDPA | NPU-compatible scaled_dot_product_attention | Correct mask handling for NPU |
| MoE GMM | torch_npu.npu_grouped_matmul | EP-aware grouped matrix multiply |
| FLA | MindSpeed Triton backend | Flash Linear Attention for Qwen3.5 |

Fused Operators in Detail

RMSNorm with Residual Parameterization

Twinkle’s NpuRMSNorm detects the residual parameterization pattern used by Qwen3.5 (where scale = 1.0 + weight) at initialization time, avoiding CPU-synchronizing Tensor.item() calls in the hot path:

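import torch
import torch.nn as nn
import torch_npu

class NpuRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Detect once at init: a mean near 0 suggests scale = 1.0 + weight.
        # Only meaningful on loaded weights (the ones() placeholder gives False).
        self._residual_param = abs(self.weight.data.mean().item()) < 0.3

    def forward(self, hidden_states):
        scale = (1.0 + self.weight) if self._residual_param else self.weight
        # npu_rms_norm returns (output, rstd); keep only the output.
        return torch_npu.npu_rms_norm(hidden_states, scale, epsilon=self.eps)[0]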
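
To sanity-check a fused kernel, it helps to have the unfused math written out. The following is a plain-Python reference for RMSNorm over a single vector, covering both parameterizations. The function name and the residual_param argument are made up for this example; only the formula itself is standard.

```python
import math


def rms_norm_reference(x, weight, eps=1e-6, residual_param=False):
    """Unfused RMSNorm over one vector, useful for comparing against a fused kernel."""
    # Residual parameterization stores weight around 0 and applies 1.0 + weight.
    scale = [1.0 + w for w in weight] if residual_param else list(weight)
    # y_i = x_i / sqrt(mean(x^2) + eps) * scale_i
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * s for v, s in zip(x, scale)]
```

A zero weight under the residual parameterization and a unit weight under the standard one produce the same output, which is why the two cases have to be told apart before choosing the scale.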

EP-Aware MoE Optimization

The MoE grouped matmul patch is EP-aware: it activates only when Expert Parallelism (EP) is enabled. With EP, each rank holds a subset of experts, so the weights are small and contiguous. Without EP, each rank holds all experts, and the transpose plus contiguous copy adds roughly 8x overhead:

TWINKLE_NPU_GMM_PATCH not set → skip (default safe)
TWINKLE_NPU_GMM_PATCH=1 + EP enabled  → apply (efficient)
TWINKLE_NPU_GMM_PATCH=1 + EP disabled → skip (avoid 8x overhead)
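
The three cases above reduce to a small predicate. This sketch assumes the variable is compared against the literal string "1" and that the caller already knows whether EP is on; the function name is hypothetical.

```python
import os


def should_apply_gmm_patch(ep_enabled, env=None):
    """Mirror the decision table: opt-in flag AND Expert Parallelism."""
    env = os.environ if env is None else env
    # Unset (or any value other than "1") keeps the safe default: skip.
    if env.get("TWINKLE_NPU_GMM_PATCH") != "1":
        return False
    # Flag set: apply only when EP keeps per-rank expert weights small.
    return bool(ep_enabled)
```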

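The GmmFunction autograd function wraps torch_npu.npu_grouped_matmul with full backward support. Prepared weights are cached and invalidated automatically when the underlying parameter changes: full-parameter training bumps the tensor's _version counter, so the cache is rebuilt, while under LoRA _version stays stable and the cached copy is reused.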
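
One way to picture version-based invalidation is the toy cache below. It uses a stand-in FakeWeight class with a _version counter in place of a real tensor, and the names VersionedCache and update_ are invented for the example; Twinkle's actual cache may be organized differently.

```python
class FakeWeight:
    """Stand-in for a tensor: in-place updates bump _version."""

    def __init__(self, data):
        self.data = list(data)
        self._version = 0

    def update_(self, data):
        self.data = list(data)
        self._version += 1


class VersionedCache:
    """Cache a derived copy of a weight, rebuilt when _version changes."""

    def __init__(self, transform):
        self._transform = transform
        # id(weight) -> (version, value). A real cache would also need to
        # guard against id() reuse after a weight is garbage collected.
        self._entries = {}

    def get(self, weight):
        hit = self._entries.get(id(weight))
        if hit is not None and hit[0] == weight._version:
            return hit[1]  # Unchanged weight (e.g. frozen under LoRA).
        value = self._transform(weight)
        self._entries[id(weight)] = (weight._version, value)
        return value
```

With a frozen weight the transform runs once; after an in-place update the next lookup recomputes it.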

Flash Linear Attention for Qwen3.5

Qwen3.5 introduces a hybrid architecture mixing standard attention with linear attention layers. Twinkle enables the FLA fast path on NPU via MindSpeed’s Triton implementation of chunk_gated_delta_rule:

  1. Force is_flash_linear_attention_available = True in transformers
  2. Replace chunk_gated_delta_rule with MindSpeed NPU-compatible implementation
  3. Traverse instantiated model to patch per-layer instances
  4. Disable CUDA-only FusedRMSNormGated that would fail on NPU
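
The four steps can be sketched with plain namespace objects standing in for the transformers modeling module and the model's layers. Everything here is illustrative: the attribute layout, the placeholder kernel, and the choice to disable FusedRMSNormGated by setting it to None are assumptions for the example.

```python
from types import SimpleNamespace


def npu_chunk_gated_delta_rule(*args, **kwargs):
    """Placeholder for the NPU-compatible kernel (hypothetical stand-in)."""
    return "npu-result"


def patch_fla(modeling_module, model_layers):
    """Apply the four patch steps; return how many layers were re-pointed."""
    # Step 1: report the fast path as available.
    modeling_module.is_flash_linear_attention_available = True
    # Step 2: swap the module-level kernel reference.
    modeling_module.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
    # Step 3: layers built before the patch hold their own reference to the
    # old function, so each instance has to be updated as well.
    patched = 0
    for layer in model_layers:
        if hasattr(layer, "chunk_gated_delta_rule"):
            layer.chunk_gated_delta_rule = npu_chunk_gated_delta_rule
            patched += 1
    # Step 4: drop the CUDA-only fused norm so callers use their fallback.
    modeling_module.FusedRMSNormGated = None
    return patched
```

Step 3 is the reason a module-level swap alone is not enough once the model has already been instantiated.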

The MindSpeed implementation provides chunked forward/backward with WY representation, supporting variable-length sequences via cu_seqlens.
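
For readers new to cu_seqlens: it is the list of cumulative sequence boundaries for a batch of variable-length sequences packed into one tensor, starting at 0. A small helper (the name is made up for this example) shows how it is derived from per-sequence lengths.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Cumulative boundaries: sequence i occupies [cu[i], cu[i + 1])."""
    return [0, *accumulate(seq_lens)]


cu = build_cu_seqlens([3, 5, 2])
# Slice a packed token list back into its individual sequences.
packed = list(range(10))
sequences = [packed[start:end] for start, end in zip(cu, cu[1:])]
```

Here cu is [0, 3, 8, 10], so the packed batch of 10 tokens splits into sequences of length 3, 5, and 2 without any padding.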

Environment Variable Control

Every optimization is independently controllable:

| Variable | Default | Description |
| --- | --- | --- |
| TWINKLE_NPU_PATCH | 1 | Master switch for all NPU patches |
| TWINKLE_NPU_FUSED_OPS | 1 | Fused operators (RMSNorm/RoPE/SwiGLU/SDPA) |
| TWINKLE_NPU_GMM_PATCH | unset | MoE grouped matmul (EP-aware) |
| TWINKLE_NPU_FLA | 1 | Flash Linear Attention |
| TWINKLE_NPU_GATED_RMSNorm_FP32 | 0 | Force FP32 for Gated RMSNorm |
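
A minimal sketch of how these switches might be resolved, using the defaults from the table. Two behaviours are assumed here rather than confirmed: only the literal value "1" counts as enabled, and turning the master switch off disables everything else. The function name is hypothetical.

```python
import os

# Defaults copied from the table; None means "unset".
DEFAULTS = {
    "TWINKLE_NPU_PATCH": "1",
    "TWINKLE_NPU_FUSED_OPS": "1",
    "TWINKLE_NPU_GMM_PATCH": None,
    "TWINKLE_NPU_FLA": "1",
    "TWINKLE_NPU_GATED_RMSNorm_FP32": "0",
}


def resolve_flags(env=None):
    """Return {variable: bool} after applying defaults and the master switch."""
    env = os.environ if env is None else env
    flags = {name: env.get(name, default) == "1" for name, default in DEFAULTS.items()}
    # Assumption: the master switch overrides every individual flag.
    if not flags["TWINKLE_NPU_PATCH"]:
        flags = {name: False for name in flags}
    return flags
```

For example, exporting TWINKLE_NPU_FLA=0 would turn off only the Flash Linear Attention patch and leave the fused operators in place.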

Supported Model Families

The patching system automatically discovers and patches compatible model families:

  • Qwen3 / Qwen3-MoE — Full operator fusion
  • Qwen3.5 / Qwen3.5-MoE — Full fusion + FLA + Gated RMSNorm
  • Qwen2.5-VL — Full fusion + multimodal RoPE
  • Dynamic discovery — Unknown models are scanned for compatible RMSNorm/RoPE/SwiGLU patterns
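
Dynamic discovery could, for instance, work by scanning a modeling module for classes whose names follow familiar suffixes. The sketch below is only a guess at the idea: the suffix table, the function name, and matching on class names (rather than inspecting the forward pass) are all assumptions.

```python
import inspect
from types import SimpleNamespace

# Hypothetical suffix -> patch kind mapping, chosen for illustration.
PATTERNS = {"RMSNorm": "rms_norm", "RotaryEmbedding": "rope", "MLP": "swiglu"}


def discover_patch_targets(module):
    """Return {class_name: patch_kind} for classes matching a known suffix."""
    targets = {}
    for name, obj in vars(module).items():
        if not inspect.isclass(obj):
            continue  # Skip functions, constants, and submodules.
        for suffix, kind in PATTERNS.items():
            if name.endswith(suffix):
                targets[name] = kind
    return targets
```

A production version would likely verify more than the name, for example checking that an MLP really has gate and up projections before treating it as SwiGLU.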

Getting Started

# Install NPU dependencies
pip install torch-npu mindspeed

# Training automatically uses NPU optimizations
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py

See the NPU Support Guide for detailed setup instructions.