
CUDA EP: Unfused Attention runner produces NaN for fp16 with head_dim > 256 #28195

@justinchuby

Description

The CUDA EP unfused Attention runner produces NaN logits for fp16 models when head_dim > 256. This affects models such as Gemma 4, which uses global_head_dim=512 for its full-attention layers.

Dispatch chain

When head_dim > 256, Flash Attention cannot dispatch (hard limit at MAX_HEAD_SIZE=256). The fallback path is:

  1. Flash Attention: ❌ blocked by head_dim > 256
  2. Memory-Efficient Attention (MEA): ❌ blocked when past_key != nullptr
  3. Unfused runner: ✅ dispatches (MHA, q_num_heads == kv_num_heads) → produces NaN
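The fallback chain above can be sketched as a small selector function. This is an illustrative Python sketch mirroring the dispatch order described in this issue, not ORT's actual C++ dispatch code; the function name and return values are made up for illustration:

```python
MAX_HEAD_SIZE = 256  # Flash Attention's hard head-size limit


def select_attention_kernel(head_dim: int, has_past_key: bool) -> str:
    """Mirror the dispatch order described above (illustrative only)."""
    if head_dim <= MAX_HEAD_SIZE:
        return "flash"               # 1. Flash Attention
    if not has_past_key:
        return "memory_efficient"    # 2. MEA, blocked when past_key != nullptr
    return "unfused"                 # 3. unfused runner (NaN in fp16)


# A Gemma 4 full-attention layer with a KV cache lands on the unfused runner:
print(select_attention_kernel(head_dim=512, has_past_key=True))  # unfused
```

With head_dim=512 and a populated KV cache, both optimized kernels are blocked, so every full-attention layer goes through the unfused path.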

Reproduction

Build an ONNX model with Attention op using:

  • q_num_heads=8, kv_num_heads=8 (MHA after KV head expansion)
  • head_dim=512 (global attention dimension)
  • scale=1.0, softcap=0.0, is_causal=1
  • dtype: float16

Run on the CUDA EP with a sequence length greater than ~3 tokens. The output logits contain NaN.

Key observations:

  • CPU EP produces correct results with the same model (both f32 and f16)
  • Sequences of 2-3 tokens work correctly on CUDA
  • Sequences ≥ 6 tokens produce NaN on CUDA
  • Removing attn_mask does not help (still NaN)
  • graph_optimization_level=basic prevents crashes but NaN persists
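One plausible failure mode consistent with these observations, though not confirmed as the root cause here, is fp16 range overflow: with head_dim=512 and scale=1.0, a QK dot product sums 512 terms, and if any intermediate is accumulated at half precision it can exceed the fp16 maximum (65504), after which softmax turns inf into NaN. A numpy sketch of that mechanism, emulating a half-precision accumulator:

```python
import numpy as np

head_dim = 512
# Each per-element product (144.0) is tiny relative to the fp16 max (65504),
# but the true dot product is 512 * 144 = 73728, which overflows fp16.
q = np.full(head_dim, 12.0, dtype=np.float16)
k = np.full(head_dim, 12.0, dtype=np.float16)

acc = np.float16(0.0)
for qi, ki in zip(q, k):            # emulate an fp16 accumulator
    acc = np.float16(acc + qi * ki)
print(acc)  # inf

# A softmax row containing inf produces NaN: inf - inf -> nan, exp(nan) -> nan
row = np.array([acc, np.float16(0.0)], dtype=np.float16)
shifted = row - row.max()           # [nan, -inf]
probs = np.exp(shifted) / np.exp(shifted).sum()
print(probs)  # [nan, nan]
```

Kernels that accumulate QK and the softmax sum in fp32 (as Flash Attention and MEA do) avoid this, which would explain why only the unfused fallback misbehaves.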

Model

Gemma 4 E2B (google/gemma-4-E2B-it) has a hybrid attention architecture:

  • Sliding-attention layers (head_dim=256): Work correctly via GroupQueryAttention with local_window_size
  • Full-attention layers (head_dim=512): Hit the unfused runner → NaN

Expected behavior

The unfused Attention runner should produce correct results for fp16 inputs with head_dim=512.

Environment

  • onnxruntime-gpu: 1.24.0 (CUDA 12)
  • GPU: NVIDIA (SM >= 80)
  • Python: 3.12

Urgency/Impact

This blocks CUDA EP inference for Gemma 4 and any model with head_dim > 256 in fp16. The only workaround is to run on CPU or avoid full-attention layers.

Metadata
Labels

ep:CUDA (issues related to the CUDA execution provider)
