From 4af848b437a79655cc6b4638d6dfcc68e31e330d Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 23 Jun 2026 12:00:06 -0700 Subject: [PATCH] Support pruning aggregate literals in optionals PiperOrigin-RevId: 936809712 --- .../optimizers/ConstantFoldingOptimizer.java | 57 +++++++++++++++---- .../ConstantFoldingOptimizerTest.java | 20 ++++++- 2 files changed, 63 insertions(+), 14 deletions(-) diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index 46f801bb8..35d181905 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -58,7 +58,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; /** @@ -309,8 +308,12 @@ private Optional maybeFold( return maybeRewriteOptional(optResult, mutableAst, node.expr()); } - return maybeAdaptEvaluatedResult(result) - .map(celExpr -> astMutator.replaceSubtree(mutableAst, celExpr, node.id())); + CelMutableExpr adaptedResult = maybeAdaptEvaluatedResult(result).orElse(null); + if (adaptedResult == null) { + return Optional.empty(); + } + + return Optional.of(astMutator.replaceSubtree(mutableAst, adaptedResult, node.id())); } private Optional maybeAdaptEvaluatedResult(Object result) { @@ -331,7 +334,7 @@ private Optional maybeAdaptEvaluatedResult(Object result) { } else if (result instanceof Map) { Map map = (Map) result; List mapEntries = new ArrayList<>(); - for (Entry entry : map.entrySet()) { + for (Map.Entry entry : map.entrySet()) { CelMutableExpr adaptedKey = maybeAdaptEvaluatedResult(entry.getKey()).orElse(null); if (adaptedKey == null) { return Optional.empty(); @@ -384,16 +387,15 @@ private Optional maybeRewriteOptional( return Optional.empty(); } - if (!CelConstant.isConstantValue(unwrappedResult)) { - // Evaluated result is not a constant. Leave the optional as is. + CelMutableExpr adaptedResult = maybeAdaptEvaluatedResult(unwrappedResult).orElse(null); + if (adaptedResult == null) { + // Evaluated result is not an adaptable constant. Leave the optional as is. return Optional.empty(); } CelMutableExpr newOptionalOfCall = CelMutableExpr.ofCall( - CelMutableCall.create( - Function.OPTIONAL_OF.getFunction(), - CelMutableExpr.ofConstant(CelConstant.ofObjectValue(unwrappedResult)))); + CelMutableCall.create(Function.OPTIONAL_OF.getFunction(), adaptedResult)); return Optional.of(astMutator.replaceSubtree(mutableAst, newOptionalOfCall, expr.id())); } @@ -530,6 +532,37 @@ private Optional maybeShortCircuitCall( "Folding variadic logical operator is not supported yet."); } + private boolean isFoldedAggregateLiteral(CelMutableExpr expr) { + if (expr.getKind().equals(Kind.CONSTANT)) { + return true; + } + if (expr.getKind().equals(Kind.LIST)) { + for (CelMutableExpr child : expr.list().elements()) { + if (!isFoldedAggregateLiteral(child)) { + return false; + } + } + return true; + } + if (expr.getKind().equals(Kind.MAP)) { + for (CelMutableExpr.CelMutableMap.Entry entry : expr.map().entries()) { + if (!isFoldedAggregateLiteral(entry.key()) || !isFoldedAggregateLiteral(entry.value())) { + return false; + } + } + return true; + } + if (expr.getKind().equals(Kind.STRUCT)) { + for (CelMutableExpr.CelMutableStruct.Entry entry : expr.struct().entries()) { + if (!isFoldedAggregateLiteral(entry.value())) { + return false; + } + } + return true; + } + return false; + } + private CelMutableAst pruneOptionalElements(CelMutableAst ast) { ImmutableList aggregateLiterals = CelNavigableMutableExpr.fromExpr(ast.expr()) @@ -588,7 +621,7 @@ private CelMutableAst pruneOptionalListElements(CelMutableAst mutableAst, CelMut continue; } else if (call.function().equals(Function.OPTIONAL_OF.getFunction())) { CelMutableExpr arg = call.args().get(0); - if (arg.getKind().equals(Kind.CONSTANT)) { + if (isFoldedAggregateLiteral(arg)) { updatedElemBuilder.add(call.args().get(0)); continue; } @@ -629,7 +662,7 @@ private CelMutableAst pruneOptionalMapElements(CelMutableAst ast, CelMutableExpr continue; } else if (call.function().equals(Function.OPTIONAL_OF.getFunction())) { CelMutableExpr arg = call.args().get(0); - if (arg.getKind().equals(Kind.CONSTANT)) { + if (isFoldedAggregateLiteral(arg)) { modified = true; entry.setOptionalEntry(false); entry.setValue(call.args().get(0)); @@ -670,7 +703,7 @@ private CelMutableAst pruneOptionalStructElements(CelMutableAst ast, CelMutableE continue; } else if (call.function().equals(Function.OPTIONAL_OF.getFunction())) { CelMutableExpr arg = call.args().get(0); - if (arg.getKind().equals(Kind.CONSTANT)) { + if (isFoldedAggregateLiteral(arg)) { modified = true; entry.setOptionalEntry(false); entry.setValue(call.args().get(0)); diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index 66f5a94d7..ee62c5ed1 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -255,6 +255,24 @@ private static Cel setupEnv(CelBuilder celBuilder) { @TestParameters( "{source: 'has({\"req\": \"Avail\"}.opt) ? ({\"req\": \"Avail\"}.req + \" \" +" + " {\"req\": \"Avail\"}.opt) : {\"req\": \"Avail\"}.req', expected: '\"Avail\"'}") + @TestParameters("{source: 'true || optional.none().hasValue()', expected: 'true'}") + @TestParameters("{source: 'false && map_var[?\"missing\"].hasValue()', expected: 'false'}") + @TestParameters("{source: '{\"hello\": [1, 2]}.?hello', expected: 'optional.of([1, 2])'}") + @TestParameters( + "{source: '{?\"key\": optional.of({\"a\": 1})}', expected: '{\"key\": {\"a\": 1}}'}") + @TestParameters( + "{source: 'TestAllTypes{?repeated_int32: optional.of([1, 2])}'," + + " expected: 'cel.expr.conformance.proto3.TestAllTypes{repeated_int32: [1, 2]}'}") + @TestParameters("{source: '[?optional.of([1, x])]', expected: '[?optional.of([1, x])]'}") + @TestParameters("{source: '[?optional.of({\"a\": x})]', expected: '[?optional.of({\"a\": x})]'}") + @TestParameters("{source: '[?optional.of({x: 1})]', expected: '[?optional.of({x: 1})]'}") + @TestParameters( + "{source: '[?optional.of(TestAllTypes{single_int32: x})]', expected:" + + " '[?optional.of(cel.expr.conformance.proto3.TestAllTypes{single_int32: x})]'}") + @TestParameters( + "{source: '[?optional.of(TestAllTypes{single_int32: 1})]', expected:" + + " '[cel.expr.conformance.proto3.TestAllTypes{single_int32: 1}]'}") + @TestParameters("{source: '[?optional.of(x)]', expected: '[?optional.of(x)]'}") // TODO: Support folding lists with mixed types. This requires mutable lists. // @TestParameters("{source: 'dyn([1]) + [1.0]'}") public void constantFold_success(String source, String expected) throws Exception { @@ -560,6 +578,4 @@ public void iterationLimitReached_throws() throws Exception { assertThrows(CelOptimizationException.class, () -> optimizer.optimize(ast)); assertThat(e).hasMessageThat().contains("Optimization failure: Max iteration count reached."); } - - }