### Description

The CUDA EP unfused Attention runner produces NaN logits for fp16 models when `head_dim > 256`. This affects models like Gemma 4, which use `global_head_dim=512` for full-attention layers.
### Dispatch chain

When `head_dim > 256`, Flash Attention cannot dispatch (hard limit at `MAX_HEAD_SIZE=256`). The fallback path is:

- Flash Attention: ❌ blocked by `head_dim > 256`
- Memory-Efficient Attention (MEA): ❌ blocked when `past_key != nullptr`
- Unfused runner: ✅ dispatches (MHA, `q_num_heads == kv_num_heads`) → produces NaN
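The fallback chain above can be sketched as a small Python function. This is illustrative only: ONNX Runtime's actual dispatch logic lives in C++, and the function and constant names here are hypothetical stand-ins for the constraints described in this issue.

```python
# Illustrative sketch of the fallback order described above -- NOT the actual
# ONNX Runtime dispatch code; names and structure are hypothetical.
MAX_HEAD_SIZE = 256  # Flash Attention hard limit on head_dim

def pick_attention_kernel(head_dim, past_key, q_num_heads, kv_num_heads):
    if head_dim <= MAX_HEAD_SIZE:
        return "flash_attention"
    if past_key is None:  # MEA is blocked when a past KV cache is present
        return "memory_efficient_attention"
    if q_num_heads == kv_num_heads:  # MHA case
        return "unfused"  # <- this path produces NaN for fp16, head_dim=512
    raise NotImplementedError("no kernel for this configuration")

# A Gemma 4 full-attention layer with a KV cache lands on the unfused runner:
print(pick_attention_kernel(512, past_key=object(), q_num_heads=8, kv_num_heads=8))
```

With a KV cache present and `head_dim=512`, both fused kernels are ruled out, which is why every decode step hits the unfused runner.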
### Reproduction

Build an ONNX model with an `Attention` op using:

- `q_num_heads=8`, `kv_num_heads=8` (MHA after KV head expansion)
- `head_dim=512` (global attention dimension)
- `scale=1.0`, `softcap=0.0`, `is_causal=1`
- dtype: `float16`

Run on the CUDA EP with a sequence length > ~3 tokens. The logits contain NaN.
Key observations:

- The CPU EP produces correct results with the same model (both fp32 and fp16)
- Sequences of 2–3 tokens work correctly on CUDA
- Sequences of ≥ 6 tokens produce NaN on CUDA
- Removing `attn_mask` does not help (still NaN)
- `graph_optimization_level=basic` prevents crashes, but the NaN persists
### Model

Gemma 4 E2B (`google/gemma-4-E2B-it`) has a hybrid attention architecture:

- Sliding-attention layers (`head_dim=256`): work correctly via `GroupQueryAttention` with `local_window_size`
- Full-attention layers (`head_dim=512`): hit the unfused runner → NaN
### Expected behavior

The unfused Attention runner should produce correct results for fp16 inputs with `head_dim=512`.
### Environment

- onnxruntime-gpu: 1.24.0 (CUDA 12)
- GPU: NVIDIA (SM >= 80)
- Python: 3.12
### Urgency/Impact

This blocks CUDA EP inference for Gemma 4 and any fp16 model with `head_dim > 256`. The only workarounds are running on the CPU EP or avoiding full-attention layers.