Attention op (opset 23) fails with bfloat16 on CPU and CUDA due to function body type mismatches #27891
Description
Summary
When the ONNX opset-23 Attention op is used with bfloat16 inputs, ORT fails in two ways:
- CPU EP – The BFloat16 kernel is not registered, so ORT falls back to AOT function-body inlining. The inlined graph contains an `Expand(13)` node for which the CPU EP has no bfloat16 kernel:
  `NOT_IMPLEMENTED: Could not find an implementation for Expand(13) node with name ''`
- CUDA EP (function-body fallback) – When Flash/MEA are unavailable and the CUDA kernel returns `NOT_IMPLEMENTED`, ORT tries to inline the function body. The ONNX Attention function body (in `onnx/defs/nn/utils.cc`) creates `FloatNegInf` and `ScalarZero` as float32 constants, then uses them in:
  ```
  MaskTri = Where(BoolMaskTri, FloatNegInf, ScalarZero)   # float32
  AttnBiasCausalOrNot = Add(AttnBias, MaskTri)            # bfloat16 + float32 → TYPE MISMATCH
  ```
  Error: `Type parameter (T) of Optype (Add) bound to different types (tensor(bfloat16) and tensor(float))`
Minimal Repro

```python
import numpy as np, onnx, onnxruntime as ort
from onnx import helper, TensorProto

B, S, H, d_k = 1, 4, 2, 8
inputs = [helper.make_tensor_value_info(n, TensorProto.BFLOAT16, [B, S, H, d_k]) for n in ['Q','K','V']]
out = helper.make_tensor_value_info('Y', TensorProto.BFLOAT16, [B, S, H, d_k])
node = helper.make_node('Attention', ['Q','K','V'], ['Y'], q_num_heads=H, kv_num_heads=H)
graph = helper.make_graph([node], 'g', inputs, [out])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 23)])
model.ir_version = 8
sess = ort.InferenceSession(model.SerializeToString(), providers=['CPUExecutionProvider'])
# → NOT_IMPLEMENTED: Could not find an implementation for Expand(13) node with name ''
```

Root Cause Analysis
CPU EP
onnxruntime/core/providers/cpu/llm/attention.cc registers Attention for float and MLFloat16 only — BFloat16 is missing. When no kernel matches, ORT falls back to function-body inlining via TryGetFunctionProto. The inlined graph uses Expand (and other ops) for which the CPU EP has no bfloat16 kernels.
Function body type mismatch
The ONNX Attention function body builder (`onnx/defs/nn/utils.cc::AttentionAppendFunctionCausalMask`) creates causal-mask constants as hard-coded float32:

```cpp
float neg_inf = -std::numeric_limits<float>::infinity();
builder.Const1D("FloatNegInf", neg_inf);  // float32
builder.Const1D("ScalarZero", 0.f);       // float32
```

When attn_mask is provided as bfloat16 (or the bias is bfloat16 from the ConstantOfShape path) and is_causal=1, the downstream `Add(AttnBias, MaskTri)` receives mixed types.
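For illustration, a minimal numpy sketch of the mask construction (float16 stands in for bfloat16 since numpy has no native bfloat16; the names mirror the function body):

```python
import numpy as np

S = 4  # sequence length

# BoolMaskTri: True strictly above the diagonal → future positions a query
# must not attend to under is_causal=1.
bool_mask_tri = np.triu(np.ones((S, S), dtype=bool), k=1)

# Where(BoolMaskTri, FloatNegInf, ScalarZero): both constants are hard-coded
# float32, so MaskTri is float32 regardless of the dtype of AttnBias.
mask_tri = np.where(bool_mask_tri, np.float32(-np.inf), np.float32(0.0))
assert mask_tri.dtype == np.float32

# ONNX's Add binds both inputs to the same type parameter T, so a bfloat16
# AttnBias plus a float32 MaskTri is rejected (numpy, by contrast, silently
# promotes). The astype below plays the role of CastLike.
attn_bias = np.zeros((S, S), dtype=np.float16)       # stand-in for bfloat16
mask_tri_typed = mask_tri.astype(attn_bias.dtype)    # CastLike(MaskTri, AttnBias)
attn_bias_causal = attn_bias + mask_tri_typed
assert attn_bias_causal.dtype == np.float16
```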
Proposed Fix
CPU EP: Register BFloat16 for the CPU Attention kernel (opset 23 & 24), upcasting to float32 for internal computation (same pattern as MLFloat16).
Function body: Add a `CastLike` to match `MaskTri` to the type of `AttnBias` before the `Add`:

```
MaskTriTyped = CastLike(MaskTri, AttnBias)
AttnBiasCausalOrNot = Add(AttnBias, MaskTriTyped)
```
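The CPU-side fix amounts to the usual upcast-compute-downcast pattern. A rough numpy sketch of the idea (float16 again stands in for bfloat16; layout and scaling are assumed for illustration, this is not ORT's actual kernel):

```python
import numpy as np

def attention_upcast(q_lo, k_lo, v_lo):
    # Upcast the low-precision inputs to float32 for the internal math.
    q, k, v = (t.astype(np.float32) for t in (q_lo, k_lo, v_lo))
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Downcast the result back to the input precision.
    return (probs @ v).astype(q_lo.dtype)

B, H, S, d = 1, 2, 4, 8
rng = np.random.default_rng(0)
q = rng.standard_normal((B, H, S, d)).astype(np.float16)
y = attention_upcast(q, q, q)
assert y.dtype == np.float16 and y.shape == (B, H, S, d)
```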
Environment
- ORT 1.24.4 (main branch at commit aadf724)
- ONNX 1.20.1
- CPU EP (x86_64 Linux)