Attention op (opset 23) fails with bfloat16 on CPU and CUDA due to function body type mismatches #27891

@justinchuby

Description

Summary

When the ONNX opset-23 Attention op is used with bfloat16 inputs, ORT fails in two ways:

  1. CPU EP – The BFloat16 kernel is not registered, so ORT falls back to ahead-of-time function-body inlining. The inlined graph contains an Expand(13) node for which the CPU EP has no bfloat16 kernel:

    NOT_IMPLEMENTED: Could not find an implementation for Expand(13) node with name ''
    
  2. CUDA EP (function-body fallback) – When FlashAttention/memory-efficient attention (MEA) are unavailable and the CUDA kernel returns NOT_IMPLEMENTED, ORT tries to inline the function body. The ONNX Attention function body (in onnx/defs/nn/utils.cc) creates FloatNegInf and ScalarZero as float32 constants, then uses them in:

    MaskTri = Where(BoolMaskTri, FloatNegInf, ScalarZero)        # float32
    AttnBiasCausalOrNot = Add(AttnBias, MaskTri)                 # bfloat16 + float32 → TYPE MISMATCH
    

    Error: Type parameter (T) of Optype (Add) bound to different types (tensor(bfloat16) and tensor(float))

Minimal Repro

import onnxruntime as ort
from onnx import helper, TensorProto

B, S, H, d_k = 1, 4, 2, 8
inputs = [helper.make_tensor_value_info(n, TensorProto.BFLOAT16, [B, S, H, d_k]) for n in ['Q','K','V']]
out = helper.make_tensor_value_info('Y', TensorProto.BFLOAT16, [B, S, H, d_k])
node = helper.make_node('Attention', ['Q','K','V'], ['Y'], q_num_heads=H, kv_num_heads=H)
graph = helper.make_graph([node], 'g', inputs, [out])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 23)])
model.ir_version = 8

sess = ort.InferenceSession(model.SerializeToString(), providers=['CPUExecutionProvider'])
# → NOT_IMPLEMENTED: Could not find an implementation for Expand(13) node with name ''

Root Cause Analysis

CPU EP

onnxruntime/core/providers/cpu/llm/attention.cc registers Attention for float and MLFloat16 only — BFloat16 is missing. When no kernel matches, ORT falls back to function-body inlining via TryGetFunctionProto. The inlined graph uses Expand (and other ops) for which the CPU EP has no bfloat16 kernels.

Function body type mismatch

The ONNX Attention function body builder (onnx/defs/nn/utils.cc::AttentionAppendFunctionCausalMask) creates causal-mask constants as hard-coded float32:

float neg_inf = -std::numeric_limits<float>::infinity();
builder.Const1D("FloatNegInf", neg_inf);   // float32
builder.Const1D("ScalarZero",  0.f);        // float32

When attn_mask is provided as bfloat16 (or the implicit bias produced via the ConstantOfShape path is bfloat16) and is_causal=1, the downstream Add(AttnBias, MaskTri) receives mixed input types.

Proposed Fix

CPU EP: Register BFloat16 for the CPU Attention kernel (opset 23 & 24), upcasting to float32 for internal computation (same pattern as MLFloat16).
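
The upcast-compute-downcast pattern can be sketched in numpy, with float16 standing in for bfloat16 (which numpy lacks natively); the function name is illustrative, not ORT's kernel code:

```python
import numpy as np

def softmax_scores_upcast(q, k, scale):
    # Upcast the narrow inputs to float32 for the matmul/softmax accumulation.
    s = (q.astype(np.float32) @ k.astype(np.float32).T) * scale
    s -= s.max(axis=-1, keepdims=True)   # numerically stable softmax
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)
    # Downcast the result back to the narrow input type.
    return p.astype(q.dtype)

q = np.random.rand(4, 8).astype(np.float16)
k = np.random.rand(4, 8).astype(np.float16)
probs = softmax_scores_upcast(q, k, 1.0 / np.sqrt(8.0))
```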

Function body: Add CastLike to match MaskTri to the type of AttnBias before the Add:

MaskTriTyped = CastLike(MaskTri, AttnBias)
AttnBiasCausalOrNot = Add(AttnBias, MaskTriTyped)

Environment

  • ORT 1.24.4 (main branch at commit aadf724)
  • ONNX 1.20.1
  • CPU EP (x86_64 Linux)

Metadata

Labels: ep:CUDA (issues related to the CUDA execution provider)