From c261cb8b059c02fa72ae242ec164e57d6196c712 Mon Sep 17 00:00:00 2001 From: matthew brian white Date: Fri, 26 Jun 2026 11:39:53 +0100 Subject: [PATCH 1/4] fix: calcite optimization adds literalagg Signed-off-by: matthew brian white --- .../isthmus/SubstraitRelVisitor.java | 68 ++++++++++++++++++- .../isthmus/OptimizerIntegrationTest.java | 66 ++++++++++++++++++ .../substrait/isthmus/Substrait2SqlTest.java | 1 + 3 files changed, 132 insertions(+), 3 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index bc4d0a9f2..e7106ac30 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -57,6 +57,8 @@ import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import com.google.common.collect.Iterables; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.util.ImmutableBitSet; import org.immutables.value.Value; @@ -348,10 +350,17 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { .filter(c -> c.getAggregation().equals(SqlStdOperatorTable.GROUP_ID)) .collect(Collectors.toList()); + // get LITERAL_AGG() function calls — injected by SubQueryRemoveRule (CALCITE-6945) as a + // null-presence indicator; they carry a RexLiteral in rexList and have no Substrait binding. 
+ List literalAggCalls = + aggregate.getAggCallList().stream() + .filter(c -> c.getAggregation().getKind() == SqlKind.LITERAL_AGG) + .collect(Collectors.toList()); + List filteredAggCalls = aggregate.getAggCallList().stream() - // remove GROUP_ID() function calls - .filter(c -> !groupIdCalls.contains(c)) + // remove GROUP_ID() and LITERAL_AGG() function calls + .filter(c -> !groupIdCalls.contains(c) && !literalAggCalls.contains(c)) .collect(Collectors.toList()); List aggCalls = @@ -388,6 +397,8 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { i + groupingFieldCount, filteredAggCalls.indexOf(aggCall) + groupingFieldCount); } else if (groupIdCalls.contains(aggCall)) { remap.add(i + groupingFieldCount, groupingSetIndex); + } else if (literalAggCalls.contains(aggCall)) { + // LITERAL_AGG handled below via Project wrapper — skip remap slot for now } else { // this should never get triggered throw new IllegalStateException( @@ -400,7 +411,58 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { } } - return builder.build(); + Rel aggRel = builder.build(); + + if (literalAggCalls.isEmpty()) { + return aggRel; + } + + if (groupings.size() > 1) { + throw new UnsupportedOperationException( + "LITERAL_AGG combined with GROUPING SETS / CUBE / ROLLUP is not supported"); + } + + // Wrap the aggregate in a Project that replaces LITERAL_AGG output positions with their + // literal values and passes through all other fields via FieldReference. + // + // The aggregate output schema is: [grouping fields..., real agg measures...] 
+ // The full output schema requested is: [grouping fields..., all agg calls (in original order)] + // For each position in the original agg call list: + // - real measure → FieldReference into the aggregate output + // - LITERAL_AGG → the literal value from aggCall.rexList + final int groupingFieldCount = + Math.toIntExact( + groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count()); + final int realAggCount = aggCalls.size(); + final int totalAggOutputFields = groupingFieldCount + realAggCount; + + // Build the project expression list: grouping fields first, then one expression per original + // agg call in declaration order. + List projectExprs = new ArrayList<>(); + for (int i = 0; i < groupingFieldCount; i++) { + projectExprs.add(FieldReference.newRootStructReference(i, aggRel.getRecordType())); + } + int realAggIndex = groupingFieldCount; // tracks next real-measure field index in aggRel output + for (AggregateCall aggCall : aggregate.getAggCallList()) { + if (literalAggCalls.contains(aggCall)) { + // Convert the RexLiteral stored in rexList to a Substrait literal expression + RexNode rexLiteral = Iterables.getOnlyElement(aggCall.rexList); + projectExprs.add(toExpression(rexLiteral)); + } else if (!groupIdCalls.contains(aggCall)) { + // real measure: pass through by reference + projectExprs.add( + FieldReference.newRootStructReference(realAggIndex, aggRel.getRecordType())); + realAggIndex++; + } + // GROUP_ID calls are not present in the outer schema here (groupings.size() <= 1 branch); + // if groupings.size() > 1 they are handled by the remap above and should not appear here + } + + return Project.builder() + .remap(Remap.offset(totalAggOutputFields, projectExprs.size())) + .expressions(projectExprs) + .input(aggRel) + .build(); } Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) { diff --git a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java 
b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java index 8f392aae2..797a701c2 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java @@ -1,18 +1,24 @@ package io.substrait.isthmus; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitSqlToCalcite; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; import java.io.IOException; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.rules.CoreRules; import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql2rel.RelDecorrelator; import org.junit.jupiter.api.Test; class OptimizerIntegrationTest extends PlanTestBase { @@ -48,4 +54,64 @@ void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOE // Conversion of the new plan should succeed SubstraitRelVisitor.convert(RelRoot.of(newPlan, relRoot.kind), EXTENSION_COLLECTION)); } + + /** + * Regression test for LITERAL_AGG handling in SubstraitRelVisitor. + * + *
<p>
Calcite's SubQueryRemoveRule (CALCITE-6945, landed in 1.38.0) rewrites correlated quantified + * comparisons (e.g. {@code <> SOME}) using {@code LITERAL_AGG(true)} as a null-presence + * indicator. SubstraitRelVisitor has no Substrait binding for {@code LITERAL_AGG}, so the + * conversion previously crashed with "UnsupportedOperationException: Unable to find binding for + * call LITERAL_AGG(true)". + * + * @see CALCITE-6945 PR + */ + @Test + void conversionHandlesLiteralAggInsertedBySubQueryRemoveRule() + throws SqlParseException, IOException { + // <> SOME with a correlated nullable column triggers SubQueryRemoveRule's + // quantified-comparison path, which inserts LITERAL_AGG(true) into the aggregate. + String query = + "select e1.l_orderkey from lineitem e1 " + + "where e1.l_quantity <> some (" + + " select l_quantity from lineitem e2 where e2.l_partkey = e1.l_partkey" + + ")"; + + RelRoot relRoot = SubstraitSqlToCalcite.convertQuery(query, TPCH_CATALOG); + + // Step 1 — SubQueryRemoveRule: rewrites RexSubQuery → LogicalCorrelate + LITERAL_AGG. + HepProgram subQueryProgram = + new HepProgramBuilder() + .addRuleInstance(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE) + .build(); + HepPlanner hepPlanner = new HepPlanner(subQueryProgram); + hepPlanner.setRoot(relRoot.rel); + RelNode afterSubQueryRemove = hepPlanner.findBestExp(); + + // Step 2 — RelDecorrelator: rewrites LogicalCorrelate → LEFT JOIN; LITERAL_AGG survives in + // the aggregate as a synthetic null-presence indicator column. + RelNode decorrelated = + RelDecorrelator.decorrelateQuery( + afterSubQueryRemove, + RelFactories.LOGICAL_BUILDER.create(relRoot.rel.getCluster(), null)); + + // Conversion must succeed and produce the correct output schema. + // The query selects a single column (l_orderkey), so the plan root must expose 1 field. + // The LITERAL_AGG wrapper emits a Project on top of the aggregate; verify that structure. 
+ io.substrait.plan.Plan.Root planRoot = + assertDoesNotThrow( + () -> + SubstraitRelVisitor.convert( + RelRoot.of(decorrelated, relRoot.kind), EXTENSION_COLLECTION)); + Rel result = planRoot.getInput(); + + // The outermost Rel visible to the caller is a Project that re-inserts the LITERAL_AGG + // literal and passes real measures through — it must expose exactly 1 output field + // (l_orderkey) matching the SELECT list. + assertInstanceOf(Project.class, result, "expected LITERAL_AGG wrapper Project at plan root"); + assertEquals( + 1, + result.getRecordType().fields().size(), + "output schema should have exactly 1 field (l_orderkey)"); + } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java b/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java index e8a5e7c23..7d363e922 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java @@ -219,4 +219,5 @@ void dateFunctions() throws Exception { assertSqlSubstraitRelRoundTrip( "select extract(month from o_orderdate),extract(year from o_orderdate),extract(day from o_orderdate) from orders"); } + } From 713501853016ab8d46660ee5c10f37281cc8c4e1 Mon Sep 17 00:00:00 2001 From: matthew brian white Date: Fri, 26 Jun 2026 15:28:43 +0100 Subject: [PATCH 2/4] fix: spotless apply Signed-off-by: matthew brian white --- .../main/java/io/substrait/isthmus/SubstraitRelVisitor.java | 2 +- .../java/io/substrait/isthmus/OptimizerIntegrationTest.java | 6 ++---- .../test/java/io/substrait/isthmus/Substrait2SqlTest.java | 1 - 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index e7106ac30..bac6bf5d5 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -1,5 
+1,6 @@ package io.substrait.isthmus; +import com.google.common.collect.Iterables; import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; @@ -57,7 +58,6 @@ import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; -import com.google.common.collect.Iterables; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.util.ImmutableBitSet; diff --git a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java index 797a701c2..9c14780d2 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java @@ -1,8 +1,8 @@ package io.substrait.isthmus; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; @@ -81,9 +81,7 @@ void conversionHandlesLiteralAggInsertedBySubQueryRemoveRule() // Step 1 — SubQueryRemoveRule: rewrites RexSubQuery → LogicalCorrelate + LITERAL_AGG. 
HepProgram subQueryProgram = - new HepProgramBuilder() - .addRuleInstance(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE) - .build(); + new HepProgramBuilder().addRuleInstance(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE).build(); HepPlanner hepPlanner = new HepPlanner(subQueryProgram); hepPlanner.setRoot(relRoot.rel); RelNode afterSubQueryRemove = hepPlanner.findBestExp(); diff --git a/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java b/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java index 7d363e922..e8a5e7c23 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java @@ -219,5 +219,4 @@ void dateFunctions() throws Exception { assertSqlSubstraitRelRoundTrip( "select extract(month from o_orderdate),extract(year from o_orderdate),extract(day from o_orderdate) from orders"); } - } From ed599da11d73694a3e26b53c85586e2f029db291 Mon Sep 17 00:00:00 2001 From: matthew brian white Date: Wed, 1 Jul 2026 09:21:38 +0100 Subject: [PATCH 3/4] fix(isthmus): use newInputRelReference for LITERAL_AGG wrapper passthroughs newRootStructReference(i, aggRel.getRecordType()) typed every passthrough column as the whole aggregate struct instead of the scalar field at that position. Switch to newInputRelReference(i, aggRel) which derives the correct per-field type from the aggregate's record type, matching the pattern used by fromGroupSet(). Strengthen the regression test to: - guard that the plan actually contains LITERAL_AGG (detects silent Calcite version regressions) - assert the re-inserted literal is BoolLiteral(true) - assert no passthrough column carries a Struct type (catches the bug directly) - assert the wrapper subtree round-trips through proto with schema intact Add a second test verifying that LITERAL_AGG combined with GROUPING SETS throws UnsupportedOperationException. 
--- .../isthmus/SubstraitRelVisitor.java | 5 +- .../isthmus/OptimizerIntegrationTest.java | 130 ++++++++++++++---- 2 files changed, 106 insertions(+), 29 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index bac6bf5d5..75a679148 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -440,7 +440,7 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { // agg call in declaration order. List projectExprs = new ArrayList<>(); for (int i = 0; i < groupingFieldCount; i++) { - projectExprs.add(FieldReference.newRootStructReference(i, aggRel.getRecordType())); + projectExprs.add(FieldReference.newInputRelReference(i, aggRel)); } int realAggIndex = groupingFieldCount; // tracks next real-measure field index in aggRel output for (AggregateCall aggCall : aggregate.getAggCallList()) { @@ -450,8 +450,7 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { projectExprs.add(toExpression(rexLiteral)); } else if (!groupIdCalls.contains(aggCall)) { // real measure: pass through by reference - projectExprs.add( - FieldReference.newRootStructReference(realAggIndex, aggRel.getRecordType())); + projectExprs.add(FieldReference.newInputRelReference(realAggIndex, aggRel)); realAggIndex++; } // GROUP_ID calls are not present in the outer schema here (groupings.size() <= 1 branch); diff --git a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java index 9c14780d2..2b669f7d1 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java @@ -2,23 +2,39 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static 
org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitSqlToCalcite; +import io.substrait.relation.Aggregate; import io.substrait.relation.Project; +import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; +import io.substrait.relation.RelProtoConverter; +import io.substrait.type.Type; import java.io.IOException; +import java.util.List; +import java.util.Optional; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; +import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.fun.SqlInternalOperators; import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql2rel.RelDecorrelator; +import org.apache.calcite.util.ImmutableBitSet; import org.junit.jupiter.api.Test; class OptimizerIntegrationTest extends PlanTestBase { @@ -69,8 +85,6 @@ void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOE @Test void conversionHandlesLiteralAggInsertedBySubQueryRemoveRule() throws SqlParseException, IOException { - // <> SOME with a correlated nullable column triggers SubQueryRemoveRule's - // quantified-comparison path, which 
inserts LITERAL_AGG(true) into the aggregate. String query = "select e1.l_orderkey from lineitem e1 " + "where e1.l_quantity <> some (" @@ -78,38 +92,102 @@ void conversionHandlesLiteralAggInsertedBySubQueryRemoveRule() + ")"; RelRoot relRoot = SubstraitSqlToCalcite.convertQuery(query, TPCH_CATALOG); - - // Step 1 — SubQueryRemoveRule: rewrites RexSubQuery → LogicalCorrelate + LITERAL_AGG. - HepProgram subQueryProgram = - new HepProgramBuilder().addRuleInstance(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE).build(); - HepPlanner hepPlanner = new HepPlanner(subQueryProgram); + HepPlanner hepPlanner = + new HepPlanner( + new HepProgramBuilder() + .addRuleInstance(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE) + .build()); hepPlanner.setRoot(relRoot.rel); - RelNode afterSubQueryRemove = hepPlanner.findBestExp(); - - // Step 2 — RelDecorrelator: rewrites LogicalCorrelate → LEFT JOIN; LITERAL_AGG survives in - // the aggregate as a synthetic null-presence indicator column. RelNode decorrelated = RelDecorrelator.decorrelateQuery( - afterSubQueryRemove, + hepPlanner.findBestExp(), RelFactories.LOGICAL_BUILDER.create(relRoot.rel.getCluster(), null)); - // Conversion must succeed and produce the correct output schema. - // The query selects a single column (l_orderkey), so the plan root must expose 1 field. - // The LITERAL_AGG wrapper emits a Project on top of the aggregate; verify that structure. + // Pin the trigger so a future Calcite bump can't silently stop exercising this path. + assertTrue(containsLiteralAgg(decorrelated), "test setup no longer produces LITERAL_AGG"); + io.substrait.plan.Plan.Root planRoot = assertDoesNotThrow( () -> SubstraitRelVisitor.convert( RelRoot.of(decorrelated, relRoot.kind), EXTENSION_COLLECTION)); - Rel result = planRoot.getInput(); - - // The outermost Rel visible to the caller is a Project that re-inserts the LITERAL_AGG - // literal and passes real measures through — it must expose exactly 1 output field - // (l_orderkey) matching the SELECT list. 
- assertInstanceOf(Project.class, result, "expected LITERAL_AGG wrapper Project at plan root"); - assertEquals( - 1, - result.getRecordType().fields().size(), - "output schema should have exactly 1 field (l_orderkey)"); + + // The fix inserts a Project directly over the Aggregate; inspect THAT, not the outer SELECT. + Project wrapper = + findProjectOverAggregate(planRoot.getInput()) + .orElseThrow(() -> new AssertionError("expected a Project wrapping the Aggregate")); + + assertTrue( + wrapper.getExpressions().stream() + .anyMatch( + e -> + e instanceof io.substrait.expression.Expression.BoolLiteral + && ((io.substrait.expression.Expression.BoolLiteral) e).value()), + "LITERAL_AGG(true) should be re-inserted as a boolean true literal"); + + // Passthroughs must carry the scalar field type, not the whole aggregate struct. + assertTrue( + wrapper.getRecordType().fields().stream().noneMatch(f -> f instanceof Type.Struct), + "wrapper columns must be scalar; fields=" + wrapper.getRecordType().fields()); + + // The wrapper subtree must survive a proto round-trip with its schema intact. 
+ ExtensionCollector ec = new ExtensionCollector(); + io.substrait.proto.Rel proto = new RelProtoConverter(ec).toProto(wrapper); + Rel rt = new ProtoRelConverter(ec, extensions).from(proto); + assertEquals(wrapper.getRecordType(), rt.getRecordType(), "wrapper schema must round-trip"); + } + + @Test + void literalAggCombinedWithGroupingSetsIsRejected() { + RelNode input = builder.values(new String[] {"a", "b"}, 1, 2, 3, 4).build(); + RexBuilder rexBuilder = creator.rex(); + AggregateCall literalAgg = + AggregateCall.create( + SqlInternalOperators.LITERAL_AGG, + false, + false, + false, + List.of(rexBuilder.makeLiteral(true)), + List.of(), + -1, + null, + RelCollations.EMPTY, + typeFactory.createSqlType(SqlTypeName.BOOLEAN), + "li"); + ImmutableBitSet g0 = ImmutableBitSet.of(0); + ImmutableBitSet g1 = ImmutableBitSet.of(1); + RelNode aggregate = + LogicalAggregate.create( + input, List.of(), g0.union(g1), List.of(g0, g1), List.of(literalAgg)); + + UnsupportedOperationException ex = + assertThrows( + UnsupportedOperationException.class, + () -> + SubstraitRelVisitor.convert( + RelRoot.of(aggregate, org.apache.calcite.sql.SqlKind.SELECT), + EXTENSION_COLLECTION)); + assertTrue(ex.getMessage().contains("GROUPING SETS"), ex.getMessage()); + } + + private static boolean containsLiteralAgg(RelNode node) { + if (node instanceof org.apache.calcite.rel.core.Aggregate agg) { + if (agg.getAggCallList().stream() + .anyMatch(c -> c.getAggregation().getKind() == SqlKind.LITERAL_AGG)) { + return true; + } + } + return node.getInputs().stream().anyMatch(OptimizerIntegrationTest::containsLiteralAgg); + } + + private static Optional findProjectOverAggregate(Rel rel) { + if (rel instanceof Project p && p.getInput() instanceof Aggregate) { + return Optional.of(p); + } + return rel.getInputs().stream() + .map(OptimizerIntegrationTest::findProjectOverAggregate) + .filter(Optional::isPresent) + .map(Optional::get) + .findFirst(); } } From 78317c3ab03f46b770037f98956d9a1f57ca5c9e Mon Sep 17 
00:00:00 2001 From: matthew brian white Date: Thu, 2 Jul 2026 09:47:26 +0100 Subject: [PATCH 4/4] fix: review comments Signed-off-by: matthew brian white --- .../isthmus/SubstraitRelVisitor.java | 52 +++++++++--------- .../isthmus/OptimizerIntegrationTest.java | 54 +++++++++++++++++++ 2 files changed, 82 insertions(+), 24 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 75a679148..bf5eedb75 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -345,17 +345,30 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { sets.filter(s -> s != null).map(s -> fromGroupSet(s, input)).collect(Collectors.toList()); // get GROUP_ID() function calls - List groupIdCalls = + java.util.Set groupIdCalls = aggregate.getAggCallList().stream() .filter(c -> c.getAggregation().equals(SqlStdOperatorTable.GROUP_ID)) - .collect(Collectors.toList()); + .collect(Collectors.toSet()); // get LITERAL_AGG() function calls — injected by SubQueryRemoveRule (CALCITE-6945) as a // null-presence indicator; they carry a RexLiteral in rexList and have no Substrait binding. - List literalAggCalls = + java.util.Set literalAggCalls = aggregate.getAggCallList().stream() .filter(c -> c.getAggregation().getKind() == SqlKind.LITERAL_AGG) - .collect(Collectors.toList()); + .collect(Collectors.toSet()); + + if (!literalAggCalls.isEmpty() && groupings.size() > 1) { + throw new UnsupportedOperationException( + "LITERAL_AGG combined with GROUPING SETS / CUBE / ROLLUP is not supported"); + } + + // Number of distinct grouping-expression output fields produced by the aggregate. + // Used by the no-GROUP_ID remap and the LITERAL_AGG project wrapper below. + // The GROUP_ID remap branch intentionally uses a non-distinct count instead (see comment + // there). 
+ final int groupingFieldCount = + Math.toIntExact( + groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count()); List filteredAggCalls = aggregate.getAggCallList().stream() @@ -374,19 +387,19 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { if (groupings.size() > 1) { // remove the grouping set index if there was no explicit GROUP_ID() function call if (groupIdCalls.isEmpty()) { - int groupingExprSize = - Math.toIntExact( - groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count()); - builder.remap(Remap.offset(0, groupingExprSize + aggCalls.size())); + builder.remap(Remap.offset(0, groupingFieldCount + aggCalls.size())); } else { - // remap grouping set index at the field positions where the GROUP_ID() function calls were - final int groupingFieldCount = + // remap grouping set index at the field positions where the GROUP_ID() function calls were. + // Use the non-distinct total here: when grouping sets share expressions the aggregate + // output + // contains one slot per (groupingSet × expression) entry, not one per distinct expression. 
+ final int groupingFieldCountWithDuplicates = Math.toIntExact(groupings.stream().flatMap(g -> g.getExpressions().stream()).count()); final int filterAggCallCount = aggCalls.size(); - final Integer groupingSetIndex = groupingFieldCount + filterAggCallCount; + final Integer groupingSetIndex = groupingFieldCountWithDuplicates + filterAggCallCount; final List remap = - IntStream.range(0, groupingFieldCount) + IntStream.range(0, groupingFieldCountWithDuplicates) .mapToObj(i -> i) .collect(Collectors.toCollection(ArrayList::new)); @@ -394,11 +407,10 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { AggregateCall aggCall = aggregate.getAggCallList().get(i); if (filteredAggCalls.contains(aggCall)) { remap.add( - i + groupingFieldCount, filteredAggCalls.indexOf(aggCall) + groupingFieldCount); + i + groupingFieldCountWithDuplicates, + filteredAggCalls.indexOf(aggCall) + groupingFieldCountWithDuplicates); } else if (groupIdCalls.contains(aggCall)) { - remap.add(i + groupingFieldCount, groupingSetIndex); - } else if (literalAggCalls.contains(aggCall)) { - // LITERAL_AGG handled below via Project wrapper — skip remap slot for now + remap.add(i + groupingFieldCountWithDuplicates, groupingSetIndex); } else { // this should never get triggered throw new IllegalStateException( @@ -417,11 +429,6 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { return aggRel; } - if (groupings.size() > 1) { - throw new UnsupportedOperationException( - "LITERAL_AGG combined with GROUPING SETS / CUBE / ROLLUP is not supported"); - } - // Wrap the aggregate in a Project that replaces LITERAL_AGG output positions with their // literal values and passes through all other fields via FieldReference. 
// @@ -430,9 +437,6 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { // For each position in the original agg call list: // - real measure → FieldReference into the aggregate output // - LITERAL_AGG → the literal value from aggCall.rexList - final int groupingFieldCount = - Math.toIntExact( - groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count()); final int realAggCount = aggCalls.size(); final int totalAggOutputFields = groupingFieldCount + realAggCount; diff --git a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java index 2b669f7d1..5836d9821 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java @@ -31,6 +31,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlInternalOperators; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql2rel.RelDecorrelator; @@ -170,6 +171,59 @@ void literalAggCombinedWithGroupingSetsIsRejected() { assertTrue(ex.getMessage().contains("GROUPING SETS"), ex.getMessage()); } + /** + * Reproduces the crash from the review comment: LITERAL_AGG first in the agg-call list, followed + * by GROUP_ID, with multiple grouping sets. The remap loop used to leave a gap and throw + * IndexOutOfBoundsException; now it should throw the clean UnsupportedOperationException before + * reaching the remap work. 
+ */ + @Test + void literalAggBeforeGroupIdWithGroupingSetsIsRejected() { + RelNode input = builder.values(new String[] {"a", "b"}, 1, 2, 3, 4).build(); + RexBuilder rexBuilder = creator.rex(); + AggregateCall literalAgg = + AggregateCall.create( + SqlInternalOperators.LITERAL_AGG, + false, + false, + false, + List.of(rexBuilder.makeLiteral(true)), + List.of(), + -1, + null, + RelCollations.EMPTY, + typeFactory.createSqlType(SqlTypeName.BOOLEAN), + "li"); + AggregateCall groupIdCall = + AggregateCall.create( + SqlStdOperatorTable.GROUP_ID, + false, + false, + false, + List.of(), + List.of(), + -1, + null, + RelCollations.EMPTY, + typeFactory.createSqlType(SqlTypeName.BIGINT), + "gid"); + ImmutableBitSet g0 = ImmutableBitSet.of(0); + ImmutableBitSet g1 = ImmutableBitSet.of(1); + // LITERAL_AGG at index 0, GROUP_ID at index 1 — the ordering that triggered the crash + RelNode aggregate = + LogicalAggregate.create( + input, List.of(), g0.union(g1), List.of(g0, g1), List.of(literalAgg, groupIdCall)); + + UnsupportedOperationException ex = + assertThrows( + UnsupportedOperationException.class, + () -> + SubstraitRelVisitor.convert( + RelRoot.of(aggregate, org.apache.calcite.sql.SqlKind.SELECT), + EXTENSION_COLLECTION)); + assertTrue(ex.getMessage().contains("GROUPING SETS"), ex.getMessage()); + } + private static boolean containsLiteralAgg(RelNode node) { if (node instanceof org.apache.calcite.rel.core.Aggregate agg) { if (agg.getAggCallList().stream()