Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 80 additions & 15 deletions isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.isthmus;

import com.google.common.collect.Iterables;
import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
Expand Down Expand Up @@ -57,6 +58,7 @@
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.ImmutableBitSet;
import org.immutables.value.Value;
Expand Down Expand Up @@ -343,15 +345,35 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
sets.filter(s -> s != null).map(s -> fromGroupSet(s, input)).collect(Collectors.toList());

// get GROUP_ID() function calls
List<AggregateCall> groupIdCalls =
java.util.Set<AggregateCall> groupIdCalls =
aggregate.getAggCallList().stream()
.filter(c -> c.getAggregation().equals(SqlStdOperatorTable.GROUP_ID))
.collect(Collectors.toList());
.collect(Collectors.toSet());

// get LITERAL_AGG() function calls — injected by SubQueryRemoveRule (CALCITE-6945) as a
// null-presence indicator; they carry a RexLiteral in rexList and have no Substrait binding.
java.util.Set<AggregateCall> literalAggCalls =
aggregate.getAggCallList().stream()
.filter(c -> c.getAggregation().getKind() == SqlKind.LITERAL_AGG)
.collect(Collectors.toSet());

if (!literalAggCalls.isEmpty() && groupings.size() > 1) {
throw new UnsupportedOperationException(
"LITERAL_AGG combined with GROUPING SETS / CUBE / ROLLUP is not supported");
}

// Number of distinct grouping-expression output fields produced by the aggregate.
// Used by the no-GROUP_ID remap and the LITERAL_AGG project wrapper below.
// The GROUP_ID remap branch intentionally uses a non-distinct count instead (see comment
// there).
final int groupingFieldCount =
Math.toIntExact(
groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count());

List<AggregateCall> filteredAggCalls =
aggregate.getAggCallList().stream()
// remove GROUP_ID() function calls
.filter(c -> !groupIdCalls.contains(c))
// remove GROUP_ID() and LITERAL_AGG() function calls
.filter(c -> !groupIdCalls.contains(c) && !literalAggCalls.contains(c))
.collect(Collectors.toList());

List<Measure> aggCalls =
Expand All @@ -365,29 +387,30 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
if (groupings.size() > 1) {
// remove the grouping set index if there was no explicit GROUP_ID() function call
if (groupIdCalls.isEmpty()) {
int groupingExprSize =
Math.toIntExact(
groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count());
builder.remap(Remap.offset(0, groupingExprSize + aggCalls.size()));
builder.remap(Remap.offset(0, groupingFieldCount + aggCalls.size()));
} else {
// remap grouping set index at the field positions where the GROUP_ID() function calls were
final int groupingFieldCount =
// remap grouping set index at the field positions where the GROUP_ID() function calls were.
// Use the non-distinct total here: when grouping sets share expressions the aggregate
// output
// contains one slot per (groupingSet × expression) entry, not one per distinct expression.
final int groupingFieldCountWithDuplicates =
Math.toIntExact(groupings.stream().flatMap(g -> g.getExpressions().stream()).count());
final int filterAggCallCount = aggCalls.size();
final Integer groupingSetIndex = groupingFieldCount + filterAggCallCount;
final Integer groupingSetIndex = groupingFieldCountWithDuplicates + filterAggCallCount;

final List<Integer> remap =
IntStream.range(0, groupingFieldCount)
IntStream.range(0, groupingFieldCountWithDuplicates)
.mapToObj(i -> i)
.collect(Collectors.toCollection(ArrayList::new));

for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
AggregateCall aggCall = aggregate.getAggCallList().get(i);
if (filteredAggCalls.contains(aggCall)) {
remap.add(
i + groupingFieldCount, filteredAggCalls.indexOf(aggCall) + groupingFieldCount);
i + groupingFieldCountWithDuplicates,
filteredAggCalls.indexOf(aggCall) + groupingFieldCountWithDuplicates);
} else if (groupIdCalls.contains(aggCall)) {
remap.add(i + groupingFieldCount, groupingSetIndex);
remap.add(i + groupingFieldCountWithDuplicates, groupingSetIndex);
} else {
// this should never get triggered
throw new IllegalStateException(
Expand All @@ -400,7 +423,49 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
}
}

return builder.build();
Rel aggRel = builder.build();

if (literalAggCalls.isEmpty()) {
return aggRel;
}

// Wrap the aggregate in a Project that replaces LITERAL_AGG output positions with their
// literal values and passes through all other fields via FieldReference.
//
// The aggregate output schema is: [grouping fields..., real agg measures...]
// The full output schema requested is: [grouping fields..., all agg calls (in original order)]
// For each position in the original agg call list:
// - real measure → FieldReference into the aggregate output
// - LITERAL_AGG → the literal value from aggCall.rexList
final int realAggCount = aggCalls.size();
final int totalAggOutputFields = groupingFieldCount + realAggCount;

// Build the project expression list: grouping fields first, then one expression per original
// agg call in declaration order.
List<Expression> projectExprs = new ArrayList<>();
for (int i = 0; i < groupingFieldCount; i++) {
projectExprs.add(FieldReference.newInputRelReference(i, aggRel));
}
int realAggIndex = groupingFieldCount; // tracks next real-measure field index in aggRel output
for (AggregateCall aggCall : aggregate.getAggCallList()) {
if (literalAggCalls.contains(aggCall)) {
// Convert the RexLiteral stored in rexList to a Substrait literal expression
RexNode rexLiteral = Iterables.getOnlyElement(aggCall.rexList);
projectExprs.add(toExpression(rexLiteral));
} else if (!groupIdCalls.contains(aggCall)) {
// real measure: pass through by reference
projectExprs.add(FieldReference.newInputRelReference(realAggIndex, aggRel));
realAggIndex++;
}
// GROUP_ID calls are not present in the outer schema here (groupings.size() <= 1 branch);
// if groupings.size() > 1 they are handled by the remap above and should not appear here
}

return Project.builder()
.remap(Remap.offset(totalAggOutputFields, projectExprs.size()))
.expressions(projectExprs)
.input(aggRel)
.build();
}

Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,41 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.sql.SubstraitSqlToCalcite;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Project;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.relation.Rel;
import io.substrait.relation.RelProtoConverter;
import io.substrait.type.Type;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlInternalOperators;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql2rel.RelDecorrelator;
import org.apache.calcite.util.ImmutableBitSet;
import org.junit.jupiter.api.Test;

class OptimizerIntegrationTest extends PlanTestBase {
Expand Down Expand Up @@ -48,4 +71,177 @@ void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOE
// Conversion of the new plan should succeed
SubstraitRelVisitor.convert(RelRoot.of(newPlan, relRoot.kind), EXTENSION_COLLECTION));
}

/**
 * Guards the {@code LITERAL_AGG} support in {@link SubstraitRelVisitor}.
 *
 * <p>Since CALCITE-6945 (Calcite 1.38.0), {@code SubQueryRemoveRule} rewrites correlated
 * quantified comparisons such as {@code <> SOME} with a {@code LITERAL_AGG(true)} call that acts
 * as a null-presence indicator. There is no Substrait binding for {@code LITERAL_AGG}, so the
 * conversion used to fail with "UnsupportedOperationException: Unable to find binding for call
 * LITERAL_AGG(true)".
 *
 * <p>The test checks that the conversion now succeeds, that the literal is re-inserted by a
 * Project placed directly over the Aggregate, that the Project only exposes scalar columns, and
 * that the Project's schema survives a proto round-trip.
 *
 * @see <a href="https://github.com/apache/calcite/pull/4296">CALCITE-6945 PR</a>
 */
@Test
void conversionHandlesLiteralAggInsertedBySubQueryRemoveRule()
    throws SqlParseException, IOException {
  String sql =
      "select e1.l_orderkey from lineitem e1 "
          + "where e1.l_quantity <> some ("
          + "  select l_quantity from lineitem e2 where e2.l_partkey = e1.l_partkey"
          + ")";

  RelRoot relRoot = SubstraitSqlToCalcite.convertQuery(sql, TPCH_CATALOG);

  // Step 1: turn the sub-query into a correlate, then decorrelate it.
  HepProgram subQueryProgram =
      new HepProgramBuilder().addRuleInstance(CoreRules.FILTER_SUB_QUERY_TO_CORRELATE).build();
  HepPlanner planner = new HepPlanner(subQueryProgram);
  planner.setRoot(relRoot.rel);
  RelNode correlated = planner.findBestExp();
  RelNode decorrelated =
      RelDecorrelator.decorrelateQuery(
          correlated, RelFactories.LOGICAL_BUILDER.create(relRoot.rel.getCluster(), null));

  // Step 2: fail loudly if a future Calcite upgrade stops producing LITERAL_AGG for this query,
  // because the remaining assertions would then no longer exercise the intended code path.
  assertTrue(containsLiteralAgg(decorrelated), "test setup no longer produces LITERAL_AGG");

  // Step 3: the conversion itself must not throw.
  RelRoot decorrelatedRoot = RelRoot.of(decorrelated, relRoot.kind);
  io.substrait.plan.Plan.Root convertedRoot =
      assertDoesNotThrow(
          () -> SubstraitRelVisitor.convert(decorrelatedRoot, EXTENSION_COLLECTION));

  // Step 4: look at the Project sitting directly on top of the Aggregate (not the outer SELECT).
  Optional<Project> maybeWrapper = findProjectOverAggregate(convertedRoot.getInput());
  Project wrapper =
      maybeWrapper.orElseThrow(
          () -> new AssertionError("expected a Project wrapping the Aggregate"));

  boolean emitsTrueLiteral =
      wrapper.getExpressions().stream()
          .anyMatch(
              e -> e instanceof io.substrait.expression.Expression.BoolLiteral lit && lit.value());
  assertTrue(emitsTrueLiteral, "LITERAL_AGG(true) should be re-inserted as a boolean true literal");

  // Pass-through columns have to be scalar fields rather than the aggregate's whole struct.
  boolean allColumnsScalar =
      wrapper.getRecordType().fields().stream().noneMatch(Type.Struct.class::isInstance);
  assertTrue(
      allColumnsScalar,
      "wrapper columns must be scalar; fields=" + wrapper.getRecordType().fields());

  // Step 5: serialise the wrapper subtree to proto and back; the schema must be unchanged.
  ExtensionCollector collector = new ExtensionCollector();
  io.substrait.proto.Rel wrapperProto = new RelProtoConverter(collector).toProto(wrapper);
  Rel roundTripped = new ProtoRelConverter(collector, extensions).from(wrapperProto);
  assertEquals(
      wrapper.getRecordType(), roundTripped.getRecordType(), "wrapper schema must round-trip");
}

/**
 * An aggregate that carries a {@code LITERAL_AGG} call together with more than one grouping set
 * must be rejected with an {@link UnsupportedOperationException} mentioning "GROUPING SETS".
 */
@Test
void literalAggCombinedWithGroupingSetsIsRejected() {
  RelNode values = builder.values(new String[] {"a", "b"}, 1, 2, 3, 4).build();
  RexBuilder rex = creator.rex();

  // LITERAL_AGG(true): no column arguments, the literal travels in the rexList.
  AggregateCall literalAggCall =
      AggregateCall.create(
          SqlInternalOperators.LITERAL_AGG,
          false,
          false,
          false,
          List.of(rex.makeLiteral(true)),
          List.of(),
          -1,
          null,
          RelCollations.EMPTY,
          typeFactory.createSqlType(SqlTypeName.BOOLEAN),
          "li");

  // Two grouping sets, one per input column.
  ImmutableBitSet firstSet = ImmutableBitSet.of(0);
  ImmutableBitSet secondSet = ImmutableBitSet.of(1);
  ImmutableBitSet fullGroupSet = firstSet.union(secondSet);
  List<ImmutableBitSet> groupSets = List.of(firstSet, secondSet);
  RelNode aggregate =
      LogicalAggregate.create(
          values, List.of(), fullGroupSet, groupSets, List.of(literalAggCall));

  UnsupportedOperationException thrown =
      assertThrows(
          UnsupportedOperationException.class,
          () ->
              SubstraitRelVisitor.convert(
                  RelRoot.of(aggregate, SqlKind.SELECT), EXTENSION_COLLECTION));
  assertTrue(thrown.getMessage().contains("GROUPING SETS"), thrown.getMessage());
}

/**
 * Covers the agg-call ordering that previously crashed: {@code LITERAL_AGG} at position 0,
 * {@code GROUP_ID} at position 1, and several grouping sets. The remap loop used to leave a gap
 * and fail with an IndexOutOfBoundsException; the conversion is now expected to raise a clean
 * {@link UnsupportedOperationException} before any remap work is attempted.
 */
@Test
void literalAggBeforeGroupIdWithGroupingSetsIsRejected() {
  RelNode values = builder.values(new String[] {"a", "b"}, 1, 2, 3, 4).build();
  RexBuilder rex = creator.rex();

  AggregateCall literalAggCall =
      AggregateCall.create(
          SqlInternalOperators.LITERAL_AGG,
          false,
          false,
          false,
          List.of(rex.makeLiteral(true)),
          List.of(),
          -1,
          null,
          RelCollations.EMPTY,
          typeFactory.createSqlType(SqlTypeName.BOOLEAN),
          "li");
  AggregateCall groupIdAggCall =
      AggregateCall.create(
          SqlStdOperatorTable.GROUP_ID,
          false,
          false,
          false,
          List.of(),
          List.of(),
          -1,
          null,
          RelCollations.EMPTY,
          typeFactory.createSqlType(SqlTypeName.BIGINT),
          "gid");

  ImmutableBitSet firstSet = ImmutableBitSet.of(0);
  ImmutableBitSet secondSet = ImmutableBitSet.of(1);
  ImmutableBitSet fullGroupSet = firstSet.union(secondSet);
  List<ImmutableBitSet> groupSets = List.of(firstSet, secondSet);

  // Order matters: LITERAL_AGG first, GROUP_ID second is the combination that used to crash.
  List<AggregateCall> orderedCalls = List.of(literalAggCall, groupIdAggCall);
  RelNode aggregate =
      LogicalAggregate.create(values, List.of(), fullGroupSet, groupSets, orderedCalls);

  UnsupportedOperationException thrown =
      assertThrows(
          UnsupportedOperationException.class,
          () ->
              SubstraitRelVisitor.convert(
                  RelRoot.of(aggregate, SqlKind.SELECT), EXTENSION_COLLECTION));
  assertTrue(thrown.getMessage().contains("GROUPING SETS"), thrown.getMessage());
}

/**
 * Reports whether {@code node} or any relational node reachable through its inputs is an
 * Aggregate holding at least one call whose aggregation kind is {@code LITERAL_AGG}.
 */
private static boolean containsLiteralAgg(RelNode node) {
  boolean foundHere =
      node instanceof org.apache.calcite.rel.core.Aggregate aggregateNode
          && aggregateNode.getAggCallList().stream()
              .anyMatch(call -> call.getAggregation().getKind() == SqlKind.LITERAL_AGG);
  // Only descend into the inputs when this node itself does not qualify.
  return foundHere
      || node.getInputs().stream().anyMatch(OptimizerIntegrationTest::containsLiteralAgg);
}

/**
 * Depth-first, pre-order search for the first {@link Project} whose direct input is an {@link
 * Aggregate}.
 *
 * @param rel root of the Substrait relation tree to search
 * @return the first matching Project, or an empty Optional when the tree contains none
 */
private static Optional<Project> findProjectOverAggregate(Rel rel) {
  if (rel instanceof Project candidate && candidate.getInput() instanceof Aggregate) {
    return Optional.of(candidate);
  }
  // Visit children in order and stop at the first subtree that yields a match.
  for (Rel child : rel.getInputs()) {
    Optional<Project> match = findProjectOverAggregate(child);
    if (match.isPresent()) {
      return match;
    }
  }
  return Optional.empty();
}
}
Loading