[DRAFT] [Plugin EP] Port graph capture/replay APIs #27958

Draft
adrianlizarraga wants to merge 10 commits into main from adrianl/PluginEp_CudaGraphCaptureReplay

Conversation

@adrianlizarraga
Contributor

Description

Not ready for review

Motivation and Context

OrtValue* dst_ptr = dst;
Ort::Status status(Ort::GetApi().CopyTensors(*ort_env, &src_ptr, &dst_ptr, nullptr, 1));
if (!status.IsOK()) {
  ORT_CXX_API_THROW(status.GetErrorMessage(), ORT_FAIL);
Contributor Author

Should probably return a status and use a gtest assert that checks that the status is ok.
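
A minimal sketch of the pattern suggested above, assuming a test helper that reports failure through its return value rather than throwing. The `Status` class is a hypothetical stand-in for `Ort::Status` so the snippet compiles without ONNX Runtime, and `CopyTensor` with its `simulate_failure` parameter is an illustrative name, not a function from this PR.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical stand-in for Ort::Status; only the two members used here are modeled.
class Status {
 public:
  Status() = default;  // default-constructed status means success
  explicit Status(std::string message) : ok_(false), message_(std::move(message)) {}

  bool IsOK() const { return ok_; }
  const std::string& GetErrorMessage() const { return message_; }

 private:
  bool ok_ = true;
  std::string message_;
};

// Hypothetical helper: returns a status instead of throwing.
// `simulate_failure` stands in for the underlying CopyTensors call failing.
Status CopyTensor(bool simulate_failure) {
  if (simulate_failure) {
    return Status("CopyTensors failed");
  }
  return Status();
}

// In a gtest test body the call site could then read:
//   Status status = CopyTensor(/*simulate_failure=*/false);
//   ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
```

With this shape, a failed copy shows up as a failed assertion at the call site, with the error message attached, instead of an exception escaping the helper.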

Contributor

Copilot AI left a comment

Pull request overview

This PR (draft) ports graph capture/replay support to the Plugin EP pathway by extending the OrtEp C API, wiring those callbacks through the plugin EP provider wrapper, and updating session initialization logic to validate/capture based on an EP-provided node-assignment policy.

Changes:

  • Bump ORT version/API version to 1.26.0 / ORT_API_VERSION=26 and add new OrtEp graph capture/replay callbacks plus OrtGraphCaptureNodeAssignmentPolicy.
  • Update InferenceSession::Initialize() to select any EP with graph capture enabled, validate graph assignment via EP policy, and cache a single EP for replay.
  • Add/extend tests for plugin EP graph capture APIs and add an end-to-end autoep WebGPU plugin EP graph capture/replay test.
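
The policy-based validation described in the list above might look roughly like the following sketch. The enum, struct, and function names are hypothetical stand-ins (the PR's actual type is `OrtGraphCaptureNodeAssignmentPolicy`), and the reading of the `ALLOW_CPU_FOR_SHAPES` policy as "nodes on the CPU EP are tolerated only when they compute shapes" is an assumption, not something the PR text confirms.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical mirror of the two policies mentioned in the PR overview.
enum class NodeAssignmentPolicy { kStrict, kAllowCpuForShapes };

// Hypothetical, simplified node description.
struct NodeInfo {
  std::string assigned_ep;    // name of the EP the node was assigned to
  bool is_shape_computation;  // true if the node only produces shape information
};

// Returns true if every node is acceptable for graph capture on `capture_ep`
// under the given policy.
bool IsAssignmentValidForCapture(const std::vector<NodeInfo>& nodes,
                                 const std::string& capture_ep,
                                 NodeAssignmentPolicy policy) {
  for (const NodeInfo& node : nodes) {
    if (node.assigned_ep == capture_ep) {
      continue;  // node runs on the capturing EP
    }
    const bool cpu_shape_node =
        node.assigned_ep == "CPUExecutionProvider" && node.is_shape_computation;
    if (policy == NodeAssignmentPolicy::kAllowCpuForShapes && cpu_shape_node) {
      continue;  // assumed exception: shape computation may stay on CPU
    }
    return false;  // any other off-EP node blocks capture
  }
  return true;
}
```

Under the strict policy, any node assigned elsewhere rejects the graph; the relaxed policy only adds the single CPU shape-node exception.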

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| VERSION_NUMBER | Bumps runtime version to 1.26.0. |
| include/onnxruntime/core/session/onnxruntime_c_api.h | Bumps ORT_API_VERSION to 26. |
| onnxruntime/core/session/onnxruntime_c_api.cc | Updates version string static assert to 1.26.0. |
| include/onnxruntime/core/session/onnxruntime_ep_c_api.h | Adds graph capture/replay callbacks to OrtEp and introduces OrtGraphCaptureNodeAssignmentPolicy. |
| include/onnxruntime/core/framework/execution_provider.h | Adds IExecutionProvider::GetGraphCaptureNodeAssignmentPolicy() with a strict default. |
| onnxruntime/core/session/inference_session.cc | Generalizes graph-capture EP selection/validation and uses EP-specified node assignment policy. |
| onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h | Exposes graph capture/replay APIs on PluginExecutionProvider. |
| onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc | Implements plugin-side forwarding for graph capture/replay and policy query with version gating. |
| onnxruntime/core/providers/webgpu/ep/ep.h | Declares plugin adapter entrypoints for graph capture/replay and assignment policy. |
| onnxruntime/core/providers/webgpu/ep/ep.cc | Wires WebGPU plugin EP adapter function pointers and forwards to EP impl. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.h | Returns ALLOW_CPU_FOR_SHAPES policy for WebGPU EP. |
| onnxruntime/core/providers/js/js_execution_provider.h | Returns ALLOW_CPU_FOR_SHAPES policy for JS EP. |
| onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h | Returns ALLOW_CPU_FOR_SHAPES policy for DML EP wrapper. |
| onnxruntime/core/providers/cuda/cuda_execution_provider.h | Returns ALLOW_CPU_FOR_SHAPES policy for CUDA EP. |
| onnxruntime/core/providers/cuda/plugin/cuda_ep.h | Declares plugin CUDA EP adapter entrypoints for graph capture/replay and policy. |
| onnxruntime/core/providers/cuda/plugin/cuda_ep.cc | Wires plugin CUDA EP adapter function pointers (currently stubbed). |
| onnxruntime/test/framework/ep_plugin_provider_test.cc | Adds unit tests for plugin EP graph capture/replay function-pointer behavior and version gating. |
| onnxruntime/test/autoep/test_graph_capture.cc | Adds end-to-end test exercising WebGPU plugin EP graph capture + replay via public APIs. |
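
The version gating mentioned for ep_plugin_provider_interfaces.cc can be pictured with the sketch below. It assumes a wrapper that consults a callback only when the plugin declares support for the API version that introduced it (26, per this PR) and the function pointer is set. `PluginEp` is a hypothetical, cut-down stand-in for `OrtEp`, and `QueryGraphCaptureEnabled` and `AlwaysEnabled` are illustrative names.

```cpp
#include <cassert>
#include <cstdint>

// First API version that defines the graph capture callbacks (ORT_API_VERSION=26 in this PR).
constexpr std::uint32_t kGraphCaptureApiVersion = 26;

// Hypothetical stand-in for OrtEp; the real struct has many more members.
struct PluginEp {
  std::uint32_t ort_version_supported;                      // version the plugin was built against
  bool (*IsGraphCaptureEnabled)(const PluginEp* this_ptr);  // may be nullptr
};

// Sample callback used to exercise the wrapper.
bool AlwaysEnabled(const PluginEp*) { return true; }

// Falls back to "disabled" when the plugin predates the callback or leaves it unset.
// In a real C API struct the callback field would not even exist in an older plugin's
// layout, which is why the version is checked before the pointer is read.
bool QueryGraphCaptureEnabled(const PluginEp& ep) {
  if (ep.ort_version_supported < kGraphCaptureApiVersion) {
    return false;
  }
  if (ep.IsGraphCaptureEnabled == nullptr) {
    return false;
  }
  return ep.IsGraphCaptureEnabled(&ep);
}
```

The fallback value is a design choice: treating an older or unset plugin as "graph capture disabled" keeps existing plugins working unchanged.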

if (!ep->IsGraphCaptureEnabled()) {
  continue;
}

Copilot AI Apr 4, 2026

Graph capture selection now iterates over all EPs and caches the first one that returns IsGraphCaptureEnabled(). Some EPs (e.g., NvTensorRTRTX EP) currently return true for IsGraphCaptureEnabled() but intentionally always return false for IsGraphCaptured() and never implement ORT-managed replay; if such an EP is cached, Run() will repeatedly recurse for “warm-up/capture” because IsGraphCaptured() never becomes true. Consider tightening the contract (only return true when ORT is allowed to manage capture/replay), updating those EPs to return false, or adding an explicit capability flag so ORT only caches EPs that will eventually report IsGraphCaptured()=true and support ReplayGraph().

Suggested change
// Some EPs may expose graph capture as an internal/provider-managed feature while intentionally
// never transitioning IsGraphCaptured() to true for ORT-managed replay. Skip those EPs here so
// InferenceSession does not cache them and repeatedly recurse through warm-up/capture in Run().
const auto& ep_type = ep->Type();
if (ep_type.find("NvTensorRTRTX") != std::string::npos) {
  LOGS(*session_logger_, WARNING)
      << "Skipping ORT-managed graph capture selection for execution provider " << ep_type
      << " because it does not support ORT-managed graph replay.";
  continue;
}
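
As an alternative to matching on the EP type string, the "explicit capability flag" option from the comment could be sketched as follows. `EpInfo`, its `supports_ort_managed_replay` field, and `SelectGraphCaptureEp` are hypothetical names used only for illustration; nothing like this flag exists in the PR as described.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified view of an EP; not the real IExecutionProvider interface.
struct EpInfo {
  std::string type;
  bool graph_capture_enabled;        // mirrors IsGraphCaptureEnabled()
  bool supports_ort_managed_replay;  // hypothetical capability flag
};

// Returns the first EP that ORT could both capture and later replay,
// or nullptr when no EP qualifies.
const EpInfo* SelectGraphCaptureEp(const std::vector<EpInfo>& eps) {
  for (const EpInfo& ep : eps) {
    if (ep.graph_capture_enabled && ep.supports_ort_managed_replay) {
      return &ep;
    }
  }
  return nullptr;
}
```

An EP that manages capture internally would set the flag to false and never be cached for replay, so the session would not keep re-entering the warm-up/capture path waiting for IsGraphCaptured() to become true.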

Comment on lines +21 to +26
Ort::ConstEpDevice FindWebGpuEpDevice() {
  auto ep_devices = ort_env->GetEpDevices();
  for (const auto& device : ep_devices) {
    if (std::string(device.EpName()) == "WebGpuExecutionProvider") {
      return device;
    }

Copilot AI Apr 4, 2026

The EP name check allocates a std::string and hard-codes the provider name. Prefer comparing via std::string_view (or strcmp) against the existing constant (kWebGpuExecutionProvider from core/graph/constants.h) to avoid allocation and reduce brittleness if the name ever changes.
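
A minimal sketch of the allocation-free comparison suggested above. The local constant `kWebGpuEpName` is a stand-in for `kWebGpuExecutionProvider` from core/graph/constants.h so the snippet compiles on its own, and `IsWebGpuEp` is a hypothetical helper name.

```cpp
#include <cassert>
#include <string_view>

// Stand-in for kWebGpuExecutionProvider (core/graph/constants.h); defined locally
// only so this sketch is self-contained.
constexpr std::string_view kWebGpuEpName = "WebGpuExecutionProvider";

// Compares a C string against the constant without constructing a std::string.
bool IsWebGpuEp(const char* ep_name) {
  return ep_name != nullptr && std::string_view(ep_name) == kWebGpuEpName;
}
```

`std::string_view` only records a pointer and a length, so the comparison avoids the heap allocation that a temporary `std::string` may incur.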
