From 9a3b384be6f5e5542ab98ef5c8aefea4cab969d6 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 3 Apr 2026 09:54:24 -0700 Subject: [PATCH] Add cel.@block test coverage for parsed-only mode PiperOrigin-RevId: 894125313 --- .../main/java/dev/cel/extensions/BUILD.bazel | 1 + .../cel/extensions/CelBindingsExtensions.java | 14 +- .../CelComprehensionsExtensions.java | 35 ++-- .../extensions/CelBindingsExtensionsTest.java | 3 +- .../dev/cel/optimizer/optimizers/BUILD.bazel | 3 + .../SubexpressionOptimizerBaselineTest.java | 185 ++++++++++++------ .../SubexpressionOptimizerTest.java | 161 ++++++++++++--- ...old_before_subexpression_unparsed.baseline | 2 +- .../resources/subexpression_unparsed.baseline | 2 +- .../java/dev/cel/runtime/planner/BUILD.bazel | 62 +++--- .../cel/runtime/planner/BlockMemoizer.java | 72 +++++++ .../dev/cel/runtime/planner/EvalBlock.java | 67 +++++++ .../cel/runtime/planner/ExecutionFrame.java | 12 ++ .../cel/runtime/planner/ProgramPlanner.java | 59 +++++- 14 files changed, 525 insertions(+), 153 deletions(-) create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index ed2d19d6f..77663f2fa 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -142,6 +142,7 @@ java_library( deps = [ "//common:compiler_common", "//common/ast", + "//common/types", "//compiler:compiler_builder", "//extensions:extension_library", "//parser:macro", diff --git a/extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java index 5eb2c2e8c..0e6537334 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java @@ -22,7 +22,11 @@ import com.google.errorprone.annotations.Immutable; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelIssue; +import dev.cel.common.CelOverloadDecl; import dev.cel.common.ast.CelExpr; +import dev.cel.common.types.ListType; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.TypeParamType; import dev.cel.compiler.CelCompilerLibrary; import dev.cel.parser.CelMacro; import dev.cel.parser.CelMacroExprFactory; @@ -62,7 +66,15 @@ public int version() { @Override public ImmutableSet functions() { - return ImmutableSet.of(); + // TODO: Add bindings for block once decorator support is available. + return ImmutableSet.of( + CelFunctionDecl.newFunctionDeclaration( + "cel.@block", + CelOverloadDecl.newGlobalOverload( + "cel_block_list", + TypeParamType.create("T"), + ListType.create(SimpleType.DYN), + TypeParamType.create("T")))); } @Override diff --git a/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java index 23663f02e..7c298a773 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java @@ -118,29 +118,18 @@ public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) { @Override public void setRuntimeOptions( CelRuntimeBuilder runtimeBuilder, RuntimeEquality runtimeEquality, CelOptions celOptions) { - for (Function function : functions) { - for (CelOverloadDecl overload : function.functionDecl.overloads()) { - switch (overload.overloadId()) { - case MAP_INSERT_OVERLOAD_MAP_MAP: - runtimeBuilder.addFunctionBindings( - CelFunctionBinding.from( - MAP_INSERT_OVERLOAD_MAP_MAP, - Map.class, - Map.class, - (map1, map2) -> mapInsertMap(map1, map2, runtimeEquality))); - break; - case MAP_INSERT_OVERLOAD_KEY_VALUE: - runtimeBuilder.addFunctionBindings( - CelFunctionBinding.from( - MAP_INSERT_OVERLOAD_KEY_VALUE, - ImmutableList.of(Map.class, Object.class, Object.class), - args -> mapInsertKeyValue(args, runtimeEquality))); - break; - default: - // Nothing to add. - } - } - } + runtimeBuilder.addFunctionBindings( + CelFunctionBinding.fromOverloads( + MAP_INSERT_FUNCTION, + CelFunctionBinding.from( + MAP_INSERT_OVERLOAD_MAP_MAP, + Map.class, + Map.class, + (map1, map2) -> mapInsertMap(map1, map2, runtimeEquality)), + CelFunctionBinding.from( + MAP_INSERT_OVERLOAD_KEY_VALUE, + ImmutableList.of(Map.class, Object.class, Object.class), + args -> mapInsertKeyValue(args, runtimeEquality)))); } @Override diff --git a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java index bc98c9816..ff9e31432 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java @@ -63,7 +63,8 @@ public void library() { CelExtensions.getExtensionLibrary("bindings", CelOptions.DEFAULT); assertThat(library.name()).isEqualTo("bindings"); assertThat(library.latest().version()).isEqualTo(0); - assertThat(library.version(0).functions()).isEmpty(); + assertThat(library.version(0).functions().stream().map(CelFunctionDecl::name)) + .containsExactly("cel.@block"); assertThat(library.version(0).macros().stream().map(CelMacro::getFunction)) .containsExactly("bind"); } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel index b0c48682a..734aa6879 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -33,6 +33,9 @@ java_library( "//parser:unparser", "//runtime", "//runtime:function_binding", + "//runtime:partial_vars", + "//runtime:program", + "//runtime:unknown_attributes", "//testing:baseline_test_case", "@maven//:junit_junit", "@maven//:com_google_testparameterinjector_test_parameter_injector", diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java index 802ef3037..6bdde2f81 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java @@ -24,6 +24,7 @@ // import com.google.testing.testsize.MediumTest; import dev.cel.bundle.Cel; import dev.cel.bundle.CelBuilder; +import dev.cel.bundle.CelExperimentalFactory; import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; @@ -43,6 +44,7 @@ import dev.cel.parser.CelUnparserFactory; import dev.cel.runtime.CelFunctionBinding; import dev.cel.testing.BaselineTestCase; +import java.util.EnumSet; import java.util.Optional; import org.junit.Before; import org.junit.Test; @@ -51,6 +53,55 @@ // @MediumTest @RunWith(TestParameterInjector.class) public class SubexpressionOptimizerBaselineTest extends BaselineTestCase { + private enum RuntimeEnv { + LEGACY(setupCelEnv(CelFactory.standardCelBuilder())), + PLANNER(setupCelEnv(CelExperimentalFactory.plannerCelBuilder())); + + private final Cel cel; + + private static Cel setupCelEnv(CelBuilder celBuilder) { + return celBuilder + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions( + CelOptions.current() + .populateMacroCalls(true) + .enableHeterogeneousNumericComparisons(true) + .build()) + .addCompilerLibraries( + CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions()) + .addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "pure_custom_func", + newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)), + CelFunctionDecl.newFunctionDeclaration( + "non_pure_custom_func", + newGlobalOverload( + "non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) + .addFunctionBindings( + // This is pure, but for the purposes of excluding it as a CSE candidate, pretend that + // it isn't. + CelFunctionBinding.fromOverloads( + "non_pure_custom_func", + CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val))) + .addFunctionBindings( + CelFunctionBinding.fromOverloads( + "pure_custom_func", + CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val))) + .addVar("x", SimpleType.DYN) + .addVar("y", SimpleType.DYN) + .addVar("opt_x", OptionalType.create(SimpleType.DYN)) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .build(); + } + + RuntimeEnv(Cel cel) { + this.cel = cel; + } + } + private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); private static final TestAllTypes TEST_ALL_TYPES_INPUT = TestAllTypes.newBuilder() @@ -67,7 +118,6 @@ public class SubexpressionOptimizerBaselineTest extends BaselineTestCase { .putMapInt32Int64(2, 2) .putMapStringString("key", "A"))) .build(); - private static final Cel CEL = newCelBuilder().build(); private static final SubexpressionOptimizerOptions OPTIMIZER_COMMON_OPTIONS = SubexpressionOptimizerOptions.newBuilder() @@ -90,45 +140,70 @@ protected String baselineFileName() { return overriddenBaseFilePath; } + @TestParameter RuntimeEnv runtimeEnv; + @Test public void allOptimizers_producesSameEvaluationResult( @TestParameter CseTestOptimizer cseTestOptimizer, @TestParameter CseTestCase cseTestCase) throws Exception { skipBaselineVerification(); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); ImmutableMap inputMap = ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L)); - Object expectedEvalResult = CEL.createProgram(ast).eval(inputMap); + Object expectedEvalResult = runtimeEnv.cel.createProgram(ast).eval(inputMap); - CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast); + CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast); - Object optimizedEvalResult = CEL.createProgram(optimizedAst).eval(inputMap); + Object optimizedEvalResult = runtimeEnv.cel.createProgram(optimizedAst).eval(inputMap); + assertThat(optimizedEvalResult).isEqualTo(expectedEvalResult); + } + + @Test + public void allOptimizers_producesSameEvaluationResult_parsedOnly( + @TestParameter CseTestCase cseTestCase, @TestParameter CseTestOptimizer cseTestOptimizer) + throws Exception { + skipBaselineVerification(); + if (runtimeEnv == RuntimeEnv.LEGACY) { + return; + } + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); + ImmutableMap inputMap = + ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L)); + Object expectedEvalResult = runtimeEnv.cel.createProgram(ast).eval(inputMap); + + CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast); + CelAbstractSyntaxTree parsedOnlyOptimizedAst = + CelAbstractSyntaxTree.newParsedAst(optimizedAst.getExpr(), optimizedAst.getSource()); + + Object optimizedEvalResult = runtimeEnv.cel.createProgram(parsedOnlyOptimizedAst).eval(inputMap); assertThat(optimizedEvalResult).isEqualTo(expectedEvalResult); } @Test public void subexpression_unparsed() throws Exception { - for (CseTestCase cseTestCase : CseTestCase.values()) { + for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); boolean resultPrinted = false; for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) { String optimizerName = cseTestOptimizer.name(); CelAbstractSyntaxTree optimizedAst; try { - optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast); + optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast); } catch (Exception e) { testOutput().printf("[%s]: Optimization Error: %s", optimizerName, e); continue; } if (!resultPrinted) { Object optimizedEvalResult = - CEL.createProgram(optimizedAst) + runtimeEnv + .cel + .createProgram(optimizedAst) .eval( ImmutableMap.of( - "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))); + "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L))); testOutput().println("Result: " + optimizedEvalResult); resultPrinted = true; } @@ -145,22 +220,24 @@ public void subexpression_unparsed() throws Exception { @Test public void constfold_before_subexpression_unparsed() throws Exception { - for (CseTestCase cseTestCase : CseTestCase.values()) { + for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); boolean resultPrinted = false; - for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) { + for (CseTestOptimizer cseTestOptimizer : EnumSet.allOf(CseTestOptimizer.class)) { String optimizerName = cseTestOptimizer.name(); CelAbstractSyntaxTree optimizedAst = - cseTestOptimizer.cseWithConstFoldingOptimizer.optimize(ast); + cseTestOptimizer.newCseWithConstFoldingOptimizer(runtimeEnv).optimize(ast); if (!resultPrinted) { Object optimizedEvalResult = - CEL.createProgram(optimizedAst) + runtimeEnv + .cel + .createProgram(optimizedAst) .eval( ImmutableMap.of( - "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))); + "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L))); testOutput().println("Result: " + optimizedEvalResult); resultPrinted = true; } @@ -179,12 +256,13 @@ public void constfold_before_subexpression_unparsed() throws Exception { public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer) throws Exception { String testBasefileName = "subexpression_ast_" + Ascii.toLowerCase(cseTestOptimizer.name()); overriddenBaseFilePath = String.format("%s%s.baseline", testdataDir(), testBasefileName); - for (CseTestCase cseTestCase : CseTestCase.values()) { + for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); - CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree optimizedAst = + newCseOptimizer(runtimeEnv.cel, cseTestOptimizer.option).optimize(ast); testOutput().println(optimizedAst.getExpr()); } } @@ -193,7 +271,8 @@ public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer) public void large_expressions_block_common_subexpr() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - CEL, SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()); + runtimeEnv.cel, + SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()); runLargeTestCases(celOptimizer); } @@ -202,7 +281,7 @@ public void large_expressions_block_common_subexpr() throws Exception { public void large_expressions_block_recursion_depth_1() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - CEL, + runtimeEnv.cel, SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .subexpressionMaxRecursionDepth(1) @@ -215,7 +294,7 @@ public void large_expressions_block_recursion_depth_1() throws Exception { public void large_expressions_block_recursion_depth_2() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - CEL, + runtimeEnv.cel, SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .subexpressionMaxRecursionDepth(2) @@ -228,7 +307,7 @@ public void large_expressions_block_recursion_depth_2() throws Exception { public void large_expressions_block_recursion_depth_3() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - CEL, + runtimeEnv.cel, SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .subexpressionMaxRecursionDepth(3) @@ -238,15 +317,16 @@ public void large_expressions_block_recursion_depth_3() throws Exception { } private void runLargeTestCases(CelOptimizer celOptimizer) throws Exception { - for (CseLargeTestCase cseTestCase : CseLargeTestCase.values()) { + for (CseLargeTestCase cseTestCase : EnumSet.allOf(CseLargeTestCase.class)) { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); - + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); Object optimizedEvalResult = - CEL.createProgram(optimizedAst) + runtimeEnv + .cel + .createProgram(optimizedAst) .eval( ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))); testOutput().println("Result: " + optimizedEvalResult); @@ -260,33 +340,6 @@ private void runLargeTestCases(CelOptimizer celOptimizer) throws Exception { } } - private static CelBuilder newCelBuilder() { - return CelFactory.standardCelBuilder() - .addMessageTypes(TestAllTypes.getDescriptor()) - .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .setOptions(CelOptions.current().populateMacroCalls(true).build()) - .addCompilerLibraries( - CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions()) - .addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions()) - .addFunctionDeclarations( - CelFunctionDecl.newFunctionDeclaration( - "pure_custom_func", - newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)), - CelFunctionDecl.newFunctionDeclaration( - "non_pure_custom_func", - newGlobalOverload("non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) - .addFunctionBindings( - // This is pure, but for the purposes of excluding it as a CSE candidate, pretend that - // it isn't. - CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val), - CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val)) - .addVar("x", SimpleType.DYN) - .addVar("y", SimpleType.DYN) - .addVar("opt_x", OptionalType.create(SimpleType.DYN)) - .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); - } - private static CelOptimizer newCseOptimizer(Cel cel, SubexpressionOptimizerOptions options) { return CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers(SubexpressionOptimizer.newInstance(options)) @@ -315,17 +368,23 @@ private enum CseTestOptimizer { BLOCK_RECURSION_DEPTH_9( OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(9).build()); - private final CelOptimizer cseOptimizer; - private final CelOptimizer cseWithConstFoldingOptimizer; + private final SubexpressionOptimizerOptions option; CseTestOptimizer(SubexpressionOptimizerOptions option) { - this.cseOptimizer = newCseOptimizer(CEL, option); - this.cseWithConstFoldingOptimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) - .addAstOptimizers( - ConstantFoldingOptimizer.getInstance(), - SubexpressionOptimizer.newInstance(option)) - .build(); + this.option = option; + } + + // Defers building the optimizer until the test runs + private CelOptimizer newCseOptimizer(RuntimeEnv env) { + return SubexpressionOptimizerBaselineTest.newCseOptimizer(env.cel, option); + } + + // Defers building the optimizer until the test runs + private CelOptimizer newCseWithConstFoldingOptimizer(RuntimeEnv env) { + return CelOptimizerFactory.standardCelOptimizerBuilder(env.cel) + .addAstOptimizers( + ConstantFoldingOptimizer.getInstance(), SubexpressionOptimizer.newInstance(option)) + .build(); } } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 2289a7d4a..23459e5d8 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -26,6 +26,7 @@ import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; import dev.cel.bundle.CelBuilder; +import dev.cel.bundle.CelExperimentalFactory; import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; @@ -52,10 +53,14 @@ import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparser; import dev.cel.parser.CelUnparserFactory; +import dev.cel.runtime.CelAttributePattern; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelRuntime; import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.runtime.CelUnknownSet; +import dev.cel.runtime.PartialVars; +import dev.cel.runtime.Program; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; @@ -64,10 +69,40 @@ @RunWith(TestParameterInjector.class) public class SubexpressionOptimizerTest { - private static final Cel CEL = newCelBuilder().build(); + private enum RuntimeEnv { + LEGACY( + setupCelEnv(CelFactory.standardCelBuilder()), + setupCelForEvaluatingBlock(CelFactory.standardCelBuilder())), + PLANNER( + setupCelEnv(CelExperimentalFactory.plannerCelBuilder()), + setupCelForEvaluatingBlock(CelExperimentalFactory.plannerCelBuilder())); - private static final Cel CEL_FOR_EVALUATING_BLOCK = - CelFactory.standardCelBuilder() + private final Cel cel; + private final Cel celForEvaluatingBlock; + + private static Cel setupCelEnv(CelBuilder celBuilder) { + return celBuilder + .addMessageTypes(TestAllTypes.getDescriptor()) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions( + CelOptions.current() + .populateMacroCalls(true) + .enableHeterogeneousNumericComparisons(true) + .build()) + .addCompilerLibraries(CelExtensions.bindings(), CelExtensions.strings()) + .addRuntimeLibraries(CelExtensions.strings()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "non_pure_custom_func", + newGlobalOverload( + "non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) + .addVar("x", SimpleType.DYN) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .build(); + } + + private static Cel setupCelForEvaluatingBlock(CelBuilder celBuilder) { + return celBuilder .setStandardMacros(CelStandardMacro.STANDARD_MACROS) .addFunctionDeclarations( // These are test only declarations, as the actual function is made internal using @ @@ -98,6 +133,15 @@ public class SubexpressionOptimizerTest { .addMessageTypes(TestAllTypes.getDescriptor()) .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) .build(); + } + + RuntimeEnv(Cel cel, Cel celForEvaluatingBlock) { + this.cel = cel; + this.celForEvaluatingBlock = celForEvaluatingBlock; + } + } + + @TestParameter RuntimeEnv runtimeEnv; private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); @@ -115,8 +159,8 @@ private static CelBuilder newCelBuilder() { .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); } - private static CelOptimizer newCseOptimizer(SubexpressionOptimizerOptions options) { - return CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + private CelOptimizer newCseOptimizer(SubexpressionOptimizerOptions options) { + return CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) .addAstOptimizers(SubexpressionOptimizer.newInstance(options)) .build(); } @@ -130,15 +174,56 @@ public void cse_resultTypeSet_celBlockOptimizationSuccess() throws Exception { SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().build())) .build(); - CelAbstractSyntaxTree ast = CEL.compile("size('a') + size('a') == 2").getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("size('a') + size('a') == 2").getAst(); CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); - assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(true); + assertThat(runtimeEnv.cel.createProgram(optimizedAst).eval()).isEqualTo(true); assertThat(CEL_UNPARSER.unparse(optimizedAst)) .isEqualTo("cel.@block([size(\"a\")], @index0 + @index0 == 2)"); } + @Test + public void cse_indexEvaluationErrors_throws() throws Exception { + CelAbstractSyntaxTree ast = + runtimeEnv.cel.compile("\"abc\".charAt(10) + \"abc\".charAt(10)").getAst(); + CelOptimizer optimizedOptimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + .addAstOptimizers(SubexpressionOptimizer.getInstance()) + .build(); + + CelAbstractSyntaxTree optimizedAst = optimizedOptimizer.optimize(ast); + + String unparsed = CEL_UNPARSER.unparse(optimizedAst); + assertThat(unparsed).isEqualTo("cel.@block([\"abc\".charAt(10)], @index0 + @index0)"); + + Program program = runtimeEnv.cel.createProgram(optimizedAst); + CelEvaluationException e = + assertThrows(CelEvaluationException.class, () -> program.eval(ImmutableMap.of())); + assertThat(e).hasMessageThat().contains("charAt failure: Index out of range: 10"); + } + + @Test + public void cse_withUnknownAttributes() throws Exception { + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("size(\"a\") == 1 ? x.y : x.y").getAst(); + CelOptimizer optimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + .addAstOptimizers(SubexpressionOptimizer.getInstance()) + .build(); + + CelAbstractSyntaxTree optimizedAst = optimizer.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)) + .isEqualTo("cel.@block([x.y], (size(\"a\") == 1) ? @index0 : @index0)"); + + Object result = + runtimeEnv + .cel + .createProgram(optimizedAst) + .eval(PartialVars.of(CelAttributePattern.fromQualifiedIdentifier("x"))); + assertThat(result).isInstanceOf(CelUnknownSet.class); + } + private enum CseNoOpTestCase { // Nothing to optimize NO_COMMON_SUBEXPR("size(\"hello\")"), @@ -169,7 +254,7 @@ private enum CseNoOpTestCase { @Test public void cse_withCelBind_noop(@TestParameter CseNoOpTestCase testCase) throws Exception { - CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(testCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = newCseOptimizer(SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()) @@ -181,7 +266,7 @@ public void cse_withCelBind_noop(@TestParameter CseNoOpTestCase testCase) throws @Test public void cse_withCelBlock_noop(@TestParameter CseNoOpTestCase testCase) throws Exception { - CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(testCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = newCseOptimizer(SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()) @@ -194,7 +279,7 @@ public void cse_withCelBlock_noop(@TestParameter CseNoOpTestCase testCase) throw @Test public void cse_withComprehensionStructureRetained() throws Exception { CelAbstractSyntaxTree ast = - CEL.compile("['foo'].map(x, [x+x]) + ['foo'].map(x, [x+x, x+x])").getAst(); + runtimeEnv.cel.compile("['foo'].map(x, [x+x]) + ['foo'].map(x, [x+x, x+x])").getAst(); CelOptimizer celOptimizer = newCseOptimizer( SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()); @@ -210,10 +295,12 @@ public void cse_withComprehensionStructureRetained() throws Exception { @Test public void cse_applyConstFoldingBefore() throws Exception { CelAbstractSyntaxTree ast = - CEL.compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") + runtimeEnv + .cel + .compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") .getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) .addAstOptimizers( ConstantFoldingOptimizer.getInstance(), SubexpressionOptimizer.newInstance( @@ -228,10 +315,12 @@ public void cse_applyConstFoldingBefore() throws Exception { @Test public void cse_applyConstFoldingAfter() throws Exception { CelAbstractSyntaxTree ast = - CEL.compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") + runtimeEnv + .cel + .compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") .getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().build()), @@ -246,9 +335,9 @@ public void cse_applyConstFoldingAfter() throws Exception { @Test public void cse_applyConstFoldingAfter_nothingToFold() throws Exception { - CelAbstractSyntaxTree ast = CEL.compile("size(x) + size(x)").getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("size(x) + size(x)").getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()), @@ -271,7 +360,7 @@ public void iterationLimitReached_throws() throws Exception { largeExprBuilder.append("+"); } } - CelAbstractSyntaxTree ast = CEL.compile(largeExprBuilder.toString()).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(largeExprBuilder.toString()).getAst(); CelOptimizationException e = assertThrows( @@ -287,9 +376,9 @@ public void iterationLimitReached_throws() throws Exception { @Test public void celBlock_astExtensionTagged() throws Exception { - CelAbstractSyntaxTree ast = CEL.compile("size(x) + size(x)").getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("size(x) + size(x)").getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()), @@ -322,7 +411,20 @@ private enum BlockTestCase { public void block_success(@TestParameter BlockTestCase testCase) throws Exception { CelAbstractSyntaxTree ast = compileUsingInternalFunctions(testCase.source); - Object evaluatedResult = CEL_FOR_EVALUATING_BLOCK.createProgram(ast).eval(); + Object evaluatedResult = runtimeEnv.celForEvaluatingBlock.createProgram(ast).eval(); + + assertThat(evaluatedResult).isNotNull(); + } + + @Test + public void block_success_parsedOnly(@TestParameter BlockTestCase testCase) throws Exception { + if (runtimeEnv == RuntimeEnv.LEGACY) { + return; + } + CelAbstractSyntaxTree ast = + compileUsingInternalFunctions(testCase.source, /* parsedOnly= */ true); + + Object evaluatedResult = runtimeEnv.celForEvaluatingBlock.createProgram(ast).eval(); assertThat(evaluatedResult).isNotNull(); } @@ -584,7 +686,7 @@ public void block_containsCycle_throws() throws Exception { CelAbstractSyntaxTree ast = compileUsingInternalFunctions("cel.block([index1,index0],index0)"); CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> CEL.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> runtimeEnv.cel.createProgram(ast).eval()); assertThat(e).hasMessageThat().contains("Cycle detected: @index0"); } @@ -595,7 +697,7 @@ public void block_lazyEvaluationContainsError_cleansUpCycleState() throws Except "cel.block([1/0 > 0], (index0 && false) || (index0 && true))"); CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> CEL.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> runtimeEnv.cel.createProgram(ast).eval()); assertThat(e).hasMessageThat().contains("/ by zero"); assertThat(e).hasMessageThat().doesNotContain("Cycle detected"); @@ -605,9 +707,10 @@ public void block_lazyEvaluationContainsError_cleansUpCycleState() throws Except * Converts AST containing cel.block related test functions to internal functions (e.g: cel.block * -> cel.@block) */ - private static CelAbstractSyntaxTree compileUsingInternalFunctions(String expression) + private CelAbstractSyntaxTree compileUsingInternalFunctions(String expression, boolean parsedOnly) throws CelValidationException { - CelAbstractSyntaxTree astToModify = CEL_FOR_EVALUATING_BLOCK.compile(expression).getAst(); + CelAbstractSyntaxTree astToModify = + runtimeEnv.celForEvaluatingBlock.compile(expression).getAst(); CelMutableAst mutableAst = CelMutableAst.fromCelAst(astToModify); CelNavigableMutableAst.fromAst(mutableAst) .getRoot() @@ -629,6 +732,14 @@ private static CelAbstractSyntaxTree compileUsingInternalFunctions(String expres indexExpr.ident().setName(internalIdentName); }); - return CEL_FOR_EVALUATING_BLOCK.check(mutableAst.toParsedAst()).getAst(); + if (parsedOnly) { + return mutableAst.toParsedAst(); + } + return runtimeEnv.celForEvaluatingBlock.check(mutableAst.toParsedAst()).getAst(); + } + + private CelAbstractSyntaxTree compileUsingInternalFunctions(String expression) + throws CelValidationException { + return compileUsingInternalFunctions(expression, false); } } diff --git a/optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline b/optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline index 55da856cd..9139c7a35 100644 --- a/optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline +++ b/optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline @@ -526,7 +526,7 @@ Result: [[[foofoo, foofoo, foofoo, foofoo], [foofoo, foofoo, foofoo, foofoo]], [ Test case: MACRO_SHADOWED_VARIABLE_COMP_V2_1 Source: [x - y - 1 > 3 ? x - y - 1 : 5].exists(x, y, x - y - 1 > 3) || x - y - 1 > 3 =====> -Result: CelUnknownSet{attributes=[], unknownExprIds=[6]} +Result: false [BLOCK_COMMON_SUBEXPR_ONLY]: cel.@block([x - y - 1, @index0 > 3], [@index1 ? @index0 : 5].exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index1) [BLOCK_RECURSION_DEPTH_1]: cel.@block([x - y, @index0 - 1, @index1 > 3, @index2 ? @index1 : 5, [@index3]], @index4.exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index2) [BLOCK_RECURSION_DEPTH_2]: cel.@block([x - y - 1, @index0 > 3, [@index1 ? @index0 : 5]], @index2.exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index1) diff --git a/optimizer/src/test/resources/subexpression_unparsed.baseline b/optimizer/src/test/resources/subexpression_unparsed.baseline index e0edc8987..780664a14 100644 --- a/optimizer/src/test/resources/subexpression_unparsed.baseline +++ b/optimizer/src/test/resources/subexpression_unparsed.baseline @@ -526,7 +526,7 @@ Result: [[[foofoo, foofoo, foofoo, foofoo], [foofoo, foofoo, foofoo, foofoo]], [ Test case: MACRO_SHADOWED_VARIABLE_COMP_V2_1 Source: [x - y - 1 > 3 ? x - y - 1 : 5].exists(x, y, x - y - 1 > 3) || x - y - 1 > 3 =====> -Result: CelUnknownSet{attributes=[], unknownExprIds=[6]} +Result: false [BLOCK_COMMON_SUBEXPR_ONLY]: cel.@block([x - y - 1, @index0 > 3], [@index1 ? @index0 : 5].exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index1) [BLOCK_RECURSION_DEPTH_1]: cel.@block([x - y, @index0 - 1, @index1 > 3, @index2 ? @index1 : 5, [@index3]], @index4.exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index2) [BLOCK_RECURSION_DEPTH_2]: cel.@block([x - y - 1, @index0 > 3, [@index1 ? @index0 : 5]], @index2.exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index1) diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index fc70118e4..cb2ad5a82 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -18,6 +18,7 @@ java_library( ":eval_and", ":eval_attribute", ":eval_binary", + ":eval_block", ":eval_conditional", ":eval_const", ":eval_create_list", @@ -67,7 +68,6 @@ java_library( srcs = ["PlannedProgram.java"], deps = [ ":error_metadata", - ":execution_frame", ":localized_evaluation_exception", ":planned_interpretable", "//:auto_value", @@ -92,11 +92,9 @@ java_library( name = "eval_const", srcs = ["EvalConstant.java"], deps = [ - ":execution_frame", ":planned_interpretable", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", - "@maven//:com_google_guava_guava", ], ) @@ -123,7 +121,6 @@ java_library( deps = [ ":activation_wrapper", ":eval_helpers", - ":execution_frame", ":planned_interpretable", ":qualifier", "//common:container", @@ -183,8 +180,8 @@ java_library( srcs = ["EvalAttribute.java"], deps = [ ":attribute", - ":execution_frame", ":interpretable_attribute", + ":planned_interpretable", ":qualifier", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", @@ -195,8 +192,8 @@ java_library( name = "eval_test_only", srcs = ["EvalTestOnly.java"], deps = [ - ":execution_frame", ":interpretable_attribute", + ":planned_interpretable", ":presence_test_qualifier", ":qualifier", "//runtime:evaluation_exception", @@ -210,7 +207,6 @@ java_library( srcs = ["EvalZeroArity.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:evaluation_exception", @@ -224,7 +220,6 @@ java_library( srcs = ["EvalUnary.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:evaluation_exception", @@ -238,7 +233,6 @@ java_library( srcs = ["EvalBinary.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -253,7 +247,6 @@ java_library( srcs = ["EvalVarArgsCall.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -268,7 +261,6 @@ java_library( srcs = ["EvalLateBoundCall.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/exceptions:overload_not_found", "//common/values", @@ -285,7 +277,6 @@ java_library( srcs = ["EvalOr.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -299,7 +290,6 @@ java_library( srcs = ["EvalAnd.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -312,7 +302,6 @@ java_library( name = "eval_conditional", srcs = ["EvalConditional.java"], deps = [ - ":execution_frame", ":planned_interpretable", "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", @@ -326,7 +315,6 @@ java_library( srcs = ["EvalCreateStruct.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/types:type_providers", "//common/values", @@ -344,7 +332,6 @@ java_library( srcs = ["EvalCreateList.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", @@ -359,7 +346,6 @@ java_library( srcs = ["EvalCreateMap.java"], deps = [ ":eval_helpers", - ":execution_frame", ":localized_evaluation_exception", ":planned_interpretable", "//common/exceptions:duplicate_key", @@ -377,7 +363,6 @@ java_library( srcs = ["EvalFold.java"], deps = [ ":activation_wrapper", - ":execution_frame", ":planned_interpretable", "//runtime:accumulated_unknowns", "//runtime:concatenated_list_view", @@ -389,24 +374,10 @@ java_library( ], ) -java_library( - name = "execution_frame", - srcs = ["ExecutionFrame.java"], - deps = [ - "//common:options", - "//common/exceptions:iteration_budget_exceeded", - "//runtime:evaluation_exception", - "//runtime:function_resolver", - "//runtime:partial_vars", - "//runtime:resolved_overload", - ], -) - java_library( name = "eval_helpers", srcs = ["EvalHelpers.java"], deps = [ - ":execution_frame", ":localized_evaluation_exception", ":planned_interpretable", "//common:error_codes", @@ -440,11 +411,20 @@ java_library( java_library( name = "planned_interpretable", - srcs = ["PlannedInterpretable.java"], + srcs = [ + "BlockMemoizer.java", + "ExecutionFrame.java", + "PlannedInterpretable.java", + ], deps = [ - ":execution_frame", + ":localized_evaluation_exception", + "//common:options", + "//common/exceptions:iteration_budget_exceeded", "//runtime:evaluation_exception", + "//runtime:function_resolver", "//runtime:interpretable", + "//runtime:partial_vars", + "//runtime:resolved_overload", "@maven//:com_google_errorprone_error_prone_annotations", ], ) @@ -454,7 +434,6 @@ java_library( srcs = ["EvalOptionalOr.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/exceptions:overload_not_found", "//runtime:accumulated_unknowns", @@ -469,7 +448,6 @@ java_library( srcs = ["EvalOptionalOrValue.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/exceptions:overload_not_found", "//runtime:accumulated_unknowns", @@ -484,7 +462,6 @@ java_library( srcs = ["EvalOptionalSelectField.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -493,3 +470,14 @@ java_library( "@maven//:com_google_guava_guava", ], ) + +java_library( + name = "eval_block", + srcs = ["EvalBlock.java"], + deps = [ + ":planned_interpretable", + "//runtime:evaluation_exception", + "//runtime:interpretable", + "@maven//:com_google_errorprone_error_prone_annotations", + ], +) diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java b/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java new file mode 100644 index 000000000..978029b3d --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java @@ -0,0 +1,72 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.GlobalResolver; +import java.util.Arrays; + +/** Handles memoization, lazy evaluation, and cycle detection for cel.@block slots. */ +final class BlockMemoizer { + + private static final Object IN_PROGRESS = new Object(); + private static final Object UNSET = new Object(); + + private final PlannedInterpretable[] slotExprs; + private final Object[] slotVals; + private final ExecutionFrame frame; + + static BlockMemoizer create(PlannedInterpretable[] slotExprs, ExecutionFrame frame) { + return new BlockMemoizer(slotExprs, frame); + } + + private BlockMemoizer(PlannedInterpretable[] slotExprs, ExecutionFrame frame) { + this.slotExprs = slotExprs; + this.frame = frame; + this.slotVals = new Object[slotExprs.length]; + Arrays.fill(this.slotVals, UNSET); + } + + Object resolveSlot(int idx, GlobalResolver resolver) { + Object val = slotVals[idx]; + + // Already evaluated + if (val != UNSET && val != IN_PROGRESS) { + if (val instanceof RuntimeException) { + throw (RuntimeException) val; + } + return val; + } + + if (val == IN_PROGRESS) { + throw new IllegalStateException("Cycle detected: @index" + idx); + } + + slotVals[idx] = IN_PROGRESS; + try { + Object result = slotExprs[idx].eval(resolver, frame); + slotVals[idx] = result; + return result; + } catch (CelEvaluationException e) { + LocalizedEvaluationException localizedException = + new LocalizedEvaluationException(e, e.getErrorCode(), slotExprs[idx].exprId()); + slotVals[idx] = localizedException; + throw localizedException; + } catch (RuntimeException e) { + slotVals[idx] = e; + throw e; + } + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java new file mode 100644 index 000000000..41ad4034e --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java @@ -0,0 +1,67 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import com.google.errorprone.annotations.Immutable; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.GlobalResolver; + +/** Eval implementation of {@code cel.@block}. */ +@Immutable +final class EvalBlock extends PlannedInterpretable { + + @SuppressWarnings("Immutable") // Array not mutated after creation + private final PlannedInterpretable[] slotExprs; + + private final PlannedInterpretable resultExpr; + + static EvalBlock create( + long exprId, PlannedInterpretable[] slotExprs, PlannedInterpretable resultExpr) { + return new EvalBlock(exprId, slotExprs, resultExpr); + } + + private EvalBlock( + long exprId, PlannedInterpretable[] slotExprs, PlannedInterpretable resultExpr) { + super(exprId); + this.slotExprs = slotExprs; + this.resultExpr = resultExpr; + } + + @Override + public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + BlockMemoizer memoizer = BlockMemoizer.create(slotExprs, frame); + frame.setBlockMemoizer(memoizer); + return resultExpr.eval(resolver, frame); + } + + @Immutable + static final class EvalBlockSlot extends PlannedInterpretable { + private final int slotIndex; + + static EvalBlockSlot create(long exprId, int slotIndex) { + return new EvalBlockSlot(exprId, slotIndex); + } + + private EvalBlockSlot(long exprId, int slotIndex) { + super(exprId); + this.slotIndex = slotIndex; + } + + @Override + public Object eval(GlobalResolver resolver, ExecutionFrame frame) { + return frame.getBlockMemoizer().resolveSlot(slotIndex, resolver); + } + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java b/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java index e29c68dd8..282b7c83a 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java @@ -30,6 +30,7 @@ final class ExecutionFrame { private final CelFunctionResolver functionResolver; private final PartialVars partialVars; private int iterationCount; + private BlockMemoizer blockMemoizer; Optional findOverload( String functionName, Collection overloadIds, Object[] args) @@ -49,6 +50,17 @@ void incrementIterations() { } } + void setBlockMemoizer(BlockMemoizer blockMemoizer) { + if (this.blockMemoizer != null) { + throw new IllegalStateException("BlockMemoizer is already initialized"); + } + this.blockMemoizer = blockMemoizer; + } + + BlockMemoizer getBlockMemoizer() { + return blockMemoizer; + } + static ExecutionFrame create( CelFunctionResolver functionResolver, PartialVars partialVars, CelOptions celOptions) { return new ExecutionFrame( diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index add918f64..9bd5f3ecd 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -164,6 +164,11 @@ private PlannedInterpretable planIdent(CelExpr celExpr, PlannerContext ctx) { } String identName = celExpr.ident().name(); + PlannedInterpretable blockSlot = maybeInterceptBlockSlot(celExpr.id(), identName).orElse(null); + if (blockSlot != null) { + return blockSlot; + } + if (ctx.isLocalVar(identName)) { return EvalAttribute.create(celExpr.id(), attributeFactory.newAbsoluteAttribute(identName)); } @@ -196,11 +201,42 @@ private PlannedInterpretable planCheckedIdent( return EvalConstant.create(id, identType); } + String identName = identRef.name(); + PlannedInterpretable blockSlot = maybeInterceptBlockSlot(id, identName).orElse(null); + if (blockSlot != null) { + return blockSlot; + } + return EvalAttribute.create(id, attributeFactory.newAbsoluteAttribute(identRef.name())); } + private Optional maybeInterceptBlockSlot(long id, String identName) { + if (!identName.startsWith("@index")) { + return Optional.empty(); + } + if (identName.length() <= 6) { + throw new IllegalArgumentException("Malformed block slot identifier: " + identName); + } + try { + int slotIndex = Integer.parseInt(identName.substring(6)); + if (slotIndex < 0) { + throw new IllegalArgumentException("Negative block slot index: " + identName); + } + return Optional.of(EvalBlock.EvalBlockSlot.create(id, slotIndex)); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid block slot index: " + identName, e); + } + } + private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { ResolvedFunction resolvedFunction = resolveFunction(expr, ctx.referenceMap()); + String functionName = resolvedFunction.functionName(); + + PlannedInterpretable blockCall = maybeInterceptBlockCall(functionName, expr, ctx).orElse(null); + if (blockCall != null) { + return blockCall; + } + CelExpr target = resolvedFunction.target().orElse(null); int argCount = expr.call().args().size(); if (target != null) { @@ -220,7 +256,6 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { evaluatedArgs[argIndex + offset] = plan(args.get(argIndex), ctx); } - String functionName = resolvedFunction.functionName(); Operator operator = Operator.findReverse(functionName).orElse(null); if (operator != null) { switch (operator) { @@ -285,6 +320,28 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { } } + private Optional maybeInterceptBlockCall( + String functionName, CelExpr expr, PlannerContext ctx) { + if (!functionName.equals("cel.@block")) { + return Optional.empty(); + } + + CelCall blockCall = expr.call(); + + if (blockCall.args().size() != 2) { + throw new IllegalArgumentException( + "Expected 2 arguments for cel.@block call. Got: " + blockCall.args().size()); + } + + CelList exprList = blockCall.args().get(0).list(); + PlannedInterpretable[] slotExprs = new PlannedInterpretable[exprList.elements().size()]; + for (int i = 0; i < slotExprs.length; i++) { + slotExprs[i] = plan(exprList.elements().get(i), ctx); + } + PlannedInterpretable resultExpr = plan(blockCall.args().get(1), ctx); + return Optional.of(EvalBlock.create(expr.id(), slotExprs, resultExpr)); + } + /** * Intercepts a potential optional function call. *