diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index bc4d0a9f2..bf5eedb75 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -1,5 +1,6 @@ package io.substrait.isthmus; +import com.google.common.collect.Iterables; import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; @@ -57,6 +58,7 @@ import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.util.ImmutableBitSet; import org.immutables.value.Value; @@ -343,15 +345,35 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { sets.filter(s -> s != null).map(s -> fromGroupSet(s, input)).collect(Collectors.toList()); // get GROUP_ID() function calls - List groupIdCalls = + java.util.Set groupIdCalls = aggregate.getAggCallList().stream() .filter(c -> c.getAggregation().equals(SqlStdOperatorTable.GROUP_ID)) - .collect(Collectors.toList()); + .collect(Collectors.toSet()); + + // get LITERAL_AGG() function calls — injected by SubQueryRemoveRule (CALCITE-6945) as a + // null-presence indicator; they carry a RexLiteral in rexList and have no Substrait binding. + java.util.Set literalAggCalls = + aggregate.getAggCallList().stream() + .filter(c -> c.getAggregation().getKind() == SqlKind.LITERAL_AGG) + .collect(Collectors.toSet()); + + if (!literalAggCalls.isEmpty() && groupings.size() > 1) { + throw new UnsupportedOperationException( + "LITERAL_AGG combined with GROUPING SETS / CUBE / ROLLUP is not supported"); + } + + // Number of distinct grouping-expression output fields produced by the aggregate. + // Used by the no-GROUP_ID remap and the LITERAL_AGG project wrapper below. + // The GROUP_ID remap branch intentionally uses a non-distinct count instead (see comment + // there). + final int groupingFieldCount = + Math.toIntExact( + groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count()); List filteredAggCalls = aggregate.getAggCallList().stream() - // remove GROUP_ID() function calls - .filter(c -> !groupIdCalls.contains(c)) + // remove GROUP_ID() and LITERAL_AGG() function calls + .filter(c -> !groupIdCalls.contains(c) && !literalAggCalls.contains(c)) .collect(Collectors.toList()); List aggCalls = @@ -365,19 +387,19 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { if (groupings.size() > 1) { // remove the grouping set index if there was no explicit GROUP_ID() function call if (groupIdCalls.isEmpty()) { - int groupingExprSize = - Math.toIntExact( - groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count()); - builder.remap(Remap.offset(0, groupingExprSize + aggCalls.size())); + builder.remap(Remap.offset(0, groupingFieldCount + aggCalls.size())); } else { - // remap grouping set index at the field positions where the GROUP_ID() function calls were - final int groupingFieldCount = + // remap grouping set index at the field positions where the GROUP_ID() function calls were. + // Use the non-distinct total here: when grouping sets share expressions the aggregate + // output + // contains one slot per (groupingSet × expression) entry, not one per distinct expression. + final int groupingFieldCountWithDuplicates = Math.toIntExact(groupings.stream().flatMap(g -> g.getExpressions().stream()).count()); final int filterAggCallCount = aggCalls.size(); - final Integer groupingSetIndex = groupingFieldCount + filterAggCallCount; + final Integer groupingSetIndex = groupingFieldCountWithDuplicates + filterAggCallCount; final List remap = - IntStream.range(0, groupingFieldCount) + IntStream.range(0, groupingFieldCountWithDuplicates) .mapToObj(i -> i) .collect(Collectors.toCollection(ArrayList::new)); @@ -385,9 +407,10 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { AggregateCall aggCall = aggregate.getAggCallList().get(i); if (filteredAggCalls.contains(aggCall)) { remap.add( - i + groupingFieldCount, filteredAggCalls.indexOf(aggCall) + groupingFieldCount); + i + groupingFieldCountWithDuplicates, + filteredAggCalls.indexOf(aggCall) + groupingFieldCountWithDuplicates); } else if (groupIdCalls.contains(aggCall)) { - remap.add(i + groupingFieldCount, groupingSetIndex); + remap.add(i + groupingFieldCountWithDuplicates, groupingSetIndex); } else { // this should never get triggered throw new IllegalStateException( @@ -400,7 +423,49 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { } } - return builder.build(); + Rel aggRel = builder.build(); + + if (literalAggCalls.isEmpty()) { + return aggRel; + } + + // Wrap the aggregate in a Project that replaces LITERAL_AGG output positions with their + // literal values and passes through all other fields via FieldReference. + // + // The aggregate output schema is: [grouping fields..., real agg measures...] + // The full output schema requested is: [grouping fields..., all agg calls (in original order)] + // For each position in the original agg call list: + // - real measure → FieldReference into the aggregate output + // - LITERAL_AGG → the literal value from aggCall.rexList + final int realAggCount = aggCalls.size(); + final int totalAggOutputFields = groupingFieldCount + realAggCount; + + // Build the project expression list: grouping fields first, then one expression per original + // agg call in declaration order. + List projectExprs = new ArrayList<>(); + for (int i = 0; i < groupingFieldCount; i++) { + projectExprs.add(FieldReference.newInputRelReference(i, aggRel)); + } + int realAggIndex = groupingFieldCount; // tracks next real-measure field index in aggRel output + for (AggregateCall aggCall : aggregate.getAggCallList()) { + if (literalAggCalls.contains(aggCall)) { + // Convert the RexLiteral stored in rexList to a Substrait literal expression + RexNode rexLiteral = Iterables.getOnlyElement(aggCall.rexList); + projectExprs.add(toExpression(rexLiteral)); + } else if (!groupIdCalls.contains(aggCall)) { + // real measure: pass through by reference + projectExprs.add(FieldReference.newInputRelReference(realAggIndex, aggRel)); + realAggIndex++; + } + // GROUP_ID calls are not present in the outer schema here (groupings.size() <= 1 branch); + // if groupings.size() > 1 they are handled by the remap above and should not appear here + } + + return Project.builder() + .remap(Remap.offset(totalAggOutputFields, projectExprs.size())) + .expressions(projectExprs) + .input(aggRel) + .build(); } Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) { diff --git a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java index 8f392aae2..5836d9821 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java @@ -1,18 +1,41 @@ package io.substrait.isthmus; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitSqlToCalcite; +import io.substrait.relation.Aggregate; +import io.substrait.relation.Project; +import io.substrait.relation.ProtoRelConverter; +import io.substrait.relation.Rel; +import io.substrait.relation.RelProtoConverter; +import io.substrait.type.Type; import java.io.IOException; +import java.util.List; +import java.util.Optional; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; +import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.fun.SqlInternalOperators; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql2rel.RelDecorrelator; +import org.apache.calcite.util.ImmutableBitSet; import org.junit.jupiter.api.Test; class OptimizerIntegrationTest extends PlanTestBase { @@ -48,4 +71,177 @@ void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOE // Conversion of the new plan should succeed SubstraitRelVisitor.convert(RelRoot.of(newPlan, relRoot.kind), EXTENSION_COLLECTION)); } + + /** + * Regression test for LITERAL_AGG handling in SubstraitRelVisitor. + * + *

Calcite's SubQueryRemoveRule (CALCITE-6945, landed in 1.38.0) rewrites correlated quantified + * comparisons (e.g. {@code <> SOME}) using {@code LITERAL_AGG(true)} as a null-presence + * indicator. SubstraitRelVisitor has no Substrait binding for {@code LITERAL_AGG}, so the + * conversion previously crashed with "UnsupportedOperationException: Unable to find binding for + * call LITERAL_AGG(true)". + * + * @see CALCITE-6945 PR + */ + @Test + void conversionHandlesLiteralAggInsertedBySubQueryRemoveRule() + throws SqlParseException, IOException { + String query = + "select e1.l_orderkey from lineitem e1 " + + "where e1.l_quantity <> some (" + + " select l_quantity from lineitem e2 where e2.l_partkey = e1.l_partkey" + + ")"; + + RelRoot relRoot = SubstraitSqlToCalcite.convertQuery(query, TPCH_CATALOG); + HepPlanner hepPlanner = + new HepPlanner( + new HepProgramBuilder() + .addRuleInstance(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE) + .build()); + hepPlanner.setRoot(relRoot.rel); + RelNode decorrelated = + RelDecorrelator.decorrelateQuery( + hepPlanner.findBestExp(), + RelFactories.LOGICAL_BUILDER.create(relRoot.rel.getCluster(), null)); + + // Pin the trigger so a future Calcite bump can't silently stop exercising this path. + assertTrue(containsLiteralAgg(decorrelated), "test setup no longer produces LITERAL_AGG"); + + io.substrait.plan.Plan.Root planRoot = + assertDoesNotThrow( + () -> + SubstraitRelVisitor.convert( + RelRoot.of(decorrelated, relRoot.kind), EXTENSION_COLLECTION)); + + // The fix inserts a Project directly over the Aggregate; inspect THAT, not the outer SELECT. + Project wrapper = + findProjectOverAggregate(planRoot.getInput()) + .orElseThrow(() -> new AssertionError("expected a Project wrapping the Aggregate")); + + assertTrue( + wrapper.getExpressions().stream() + .anyMatch( + e -> + e instanceof io.substrait.expression.Expression.BoolLiteral + && ((io.substrait.expression.Expression.BoolLiteral) e).value()), + "LITERAL_AGG(true) should be re-inserted as a boolean true literal"); + + // Passthroughs must carry the scalar field type, not the whole aggregate struct. + assertTrue( + wrapper.getRecordType().fields().stream().noneMatch(f -> f instanceof Type.Struct), + "wrapper columns must be scalar; fields=" + wrapper.getRecordType().fields()); + + // The wrapper subtree must survive a proto round-trip with its schema intact. + ExtensionCollector ec = new ExtensionCollector(); + io.substrait.proto.Rel proto = new RelProtoConverter(ec).toProto(wrapper); + Rel rt = new ProtoRelConverter(ec, extensions).from(proto); + assertEquals(wrapper.getRecordType(), rt.getRecordType(), "wrapper schema must round-trip"); + } + + @Test + void literalAggCombinedWithGroupingSetsIsRejected() { + RelNode input = builder.values(new String[] {"a", "b"}, 1, 2, 3, 4).build(); + RexBuilder rexBuilder = creator.rex(); + AggregateCall literalAgg = + AggregateCall.create( + SqlInternalOperators.LITERAL_AGG, + false, + false, + false, + List.of(rexBuilder.makeLiteral(true)), + List.of(), + -1, + null, + RelCollations.EMPTY, + typeFactory.createSqlType(SqlTypeName.BOOLEAN), + "li"); + ImmutableBitSet g0 = ImmutableBitSet.of(0); + ImmutableBitSet g1 = ImmutableBitSet.of(1); + RelNode aggregate = + LogicalAggregate.create( + input, List.of(), g0.union(g1), List.of(g0, g1), List.of(literalAgg)); + + UnsupportedOperationException ex = + assertThrows( + UnsupportedOperationException.class, + () -> + SubstraitRelVisitor.convert( + RelRoot.of(aggregate, org.apache.calcite.sql.SqlKind.SELECT), + EXTENSION_COLLECTION)); + assertTrue(ex.getMessage().contains("GROUPING SETS"), ex.getMessage()); + } + + /** + * Reproduces the crash from the review comment: LITERAL_AGG first in the agg-call list, followed + * by GROUP_ID, with multiple grouping sets. The remap loop used to leave a gap and throw + * IndexOutOfBoundsException; now it should throw the clean UnsupportedOperationException before + * reaching the remap work. + */ + @Test + void literalAggBeforeGroupIdWithGroupingSetsIsRejected() { + RelNode input = builder.values(new String[] {"a", "b"}, 1, 2, 3, 4).build(); + RexBuilder rexBuilder = creator.rex(); + AggregateCall literalAgg = + AggregateCall.create( + SqlInternalOperators.LITERAL_AGG, + false, + false, + false, + List.of(rexBuilder.makeLiteral(true)), + List.of(), + -1, + null, + RelCollations.EMPTY, + typeFactory.createSqlType(SqlTypeName.BOOLEAN), + "li"); + AggregateCall groupIdCall = + AggregateCall.create( + SqlStdOperatorTable.GROUP_ID, + false, + false, + false, + List.of(), + List.of(), + -1, + null, + RelCollations.EMPTY, + typeFactory.createSqlType(SqlTypeName.BIGINT), + "gid"); + ImmutableBitSet g0 = ImmutableBitSet.of(0); + ImmutableBitSet g1 = ImmutableBitSet.of(1); + // LITERAL_AGG at index 0, GROUP_ID at index 1 — the ordering that triggered the crash + RelNode aggregate = + LogicalAggregate.create( + input, List.of(), g0.union(g1), List.of(g0, g1), List.of(literalAgg, groupIdCall)); + + UnsupportedOperationException ex = + assertThrows( + UnsupportedOperationException.class, + () -> + SubstraitRelVisitor.convert( + RelRoot.of(aggregate, org.apache.calcite.sql.SqlKind.SELECT), + EXTENSION_COLLECTION)); + assertTrue(ex.getMessage().contains("GROUPING SETS"), ex.getMessage()); + } + + private static boolean containsLiteralAgg(RelNode node) { + if (node instanceof org.apache.calcite.rel.core.Aggregate agg) { + if (agg.getAggCallList().stream() + .anyMatch(c -> c.getAggregation().getKind() == SqlKind.LITERAL_AGG)) { + return true; + } + } + return node.getInputs().stream().anyMatch(OptimizerIntegrationTest::containsLiteralAgg); + } + + private static Optional findProjectOverAggregate(Rel rel) { + if (rel instanceof Project p && p.getInput() instanceof Aggregate) { + return Optional.of(p); + } + return rel.getInputs().stream() + .map(OptimizerIntegrationTest::findProjectOverAggregate) + .filter(Optional::isPresent) + .map(Optional::get) + .findFirst(); + } }