[DRAFT] [Plugin EP] Port graph capture/replay APIs#27958
[DRAFT] [Plugin EP] Port graph capture/replay APIs#27958adrianlizarraga wants to merge 10 commits intomainfrom
Conversation
…ve the hardcoded list of EPs that support graph capture from inference_session.cc
| OrtValue* dst_ptr = dst; | ||
| Ort::Status status(Ort::GetApi().CopyTensors(*ort_env, &src_ptr, &dst_ptr, nullptr, 1)); | ||
| if (!status.IsOK()) { | ||
| ORT_CXX_API_THROW(status.GetErrorMessage(), ORT_FAIL); |
There was a problem hiding this comment.
Should probably return a status and use a gtest assert that checks that the status is ok.
There was a problem hiding this comment.
Pull request overview
This PR (draft) ports graph capture/replay support to the Plugin EP pathway by extending the OrtEp C API, wiring those callbacks through the plugin EP provider wrapper, and updating session initialization logic to validate/capture based on an EP-provided node-assignment policy.
Changes:
- Bump ORT version/API version to 1.26.0 /
ORT_API_VERSION=26and add newOrtEpgraph capture/replay callbacks plusOrtGraphCaptureNodeAssignmentPolicy. - Update
InferenceSession::Initialize()to select any EP with graph capture enabled, validate graph assignment via EP policy, and cache a single EP for replay. - Add/extend tests for plugin EP graph capture APIs and add an end-to-end autoep WebGPU plugin EP graph capture/replay test.
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| VERSION_NUMBER | Bumps runtime version to 1.26.0. |
| include/onnxruntime/core/session/onnxruntime_c_api.h | Bumps ORT_API_VERSION to 26. |
| onnxruntime/core/session/onnxruntime_c_api.cc | Updates version string static assert to 1.26.0. |
| include/onnxruntime/core/session/onnxruntime_ep_c_api.h | Adds graph capture/replay callbacks to OrtEp and introduces OrtGraphCaptureNodeAssignmentPolicy. |
| include/onnxruntime/core/framework/execution_provider.h | Adds IExecutionProvider::GetGraphCaptureNodeAssignmentPolicy() with a strict default. |
| onnxruntime/core/session/inference_session.cc | Generalizes graph-capture EP selection/validation and uses EP-specified node assignment policy. |
| onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h | Exposes graph capture/replay APIs on PluginExecutionProvider. |
| onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc | Implements plugin-side forwarding for graph capture/replay and policy query with version gating. |
| onnxruntime/core/providers/webgpu/ep/ep.h | Declares plugin adapter entrypoints for graph capture/replay and assignment policy. |
| onnxruntime/core/providers/webgpu/ep/ep.cc | Wires WebGPU plugin EP adapter function pointers and forwards to EP impl. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.h | Returns ALLOW_CPU_FOR_SHAPES policy for WebGPU EP. |
| onnxruntime/core/providers/js/js_execution_provider.h | Returns ALLOW_CPU_FOR_SHAPES policy for JS EP. |
| onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h | Returns ALLOW_CPU_FOR_SHAPES policy for DML EP wrapper. |
| onnxruntime/core/providers/cuda/cuda_execution_provider.h | Returns ALLOW_CPU_FOR_SHAPES policy for CUDA EP. |
| onnxruntime/core/providers/cuda/plugin/cuda_ep.h | Declares plugin CUDA EP adapter entrypoints for graph capture/replay and policy. |
| onnxruntime/core/providers/cuda/plugin/cuda_ep.cc | Wires plugin CUDA EP adapter function pointers (currently stubbed). |
| onnxruntime/test/framework/ep_plugin_provider_test.cc | Adds unit tests for plugin EP graph capture/replay function-pointer behavior and version gating. |
| onnxruntime/test/autoep/test_graph_capture.cc | Adds end-to-end test exercising WebGPU plugin EP graph capture + replay via public APIs. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if (!ep->IsGraphCaptureEnabled()) { | ||
| continue; | ||
| } | ||
|
|
There was a problem hiding this comment.
Graph capture selection now iterates over all EPs and caches the first one that returns IsGraphCaptureEnabled(). Some EPs (e.g., NvTensorRTRTX EP) currently return true for IsGraphCaptureEnabled() but intentionally always return false for IsGraphCaptured() and never implement ORT-managed replay; if such an EP is cached, Run() will repeatedly recurse for “warm-up/capture” because IsGraphCaptured() never becomes true. Consider tightening the contract (only return true when ORT is allowed to manage capture/replay), updating those EPs to return false, or adding an explicit capability flag so ORT only caches EPs that will eventually report IsGraphCaptured()=true and support ReplayGraph().
| // Some EPs may expose graph capture as an internal/provider-managed feature while intentionally | |
| // never transitioning IsGraphCaptured() to true for ORT-managed replay. Skip those EPs here so | |
| // InferenceSession does not cache them and repeatedly recurse through warm-up/capture in Run(). | |
| const auto& ep_type = ep->Type(); | |
| if (ep_type.find("NvTensorRTRTX") != std::string::npos) { | |
| LOGS(*session_logger_, WARNING) | |
| << "Skipping ORT-managed graph capture selection for execution provider " << ep_type | |
| << " because it does not support ORT-managed graph replay."; | |
| continue; | |
| } |
| Ort::ConstEpDevice FindWebGpuEpDevice() { | ||
| auto ep_devices = ort_env->GetEpDevices(); | ||
| for (const auto& device : ep_devices) { | ||
| if (std::string(device.EpName()) == "WebGpuExecutionProvider") { | ||
| return device; | ||
| } |
There was a problem hiding this comment.
The EP name check allocates a std::string and hard-codes the provider name. Prefer comparing via std::string_view (or strcmp) against the existing constant (kWebGpuExecutionProvider from core/graph/constants.h) to avoid allocation and reduce brittleness if the name ever changes.
Description
Not ready for review
Motivation and Context