Skip to content

Commit 9a7e29f

Browse files
committed
Apply complete PR opensearch-project#4612 changes: multi-field binning, AVG rounding, timestamp support
Signed-off-by: Kai Huang <[email protected]>
1 parent 3c31f22 commit 9a7e29f

File tree

1 file changed

+199
-23
lines changed

1 file changed

+199
-23
lines changed

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 199 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,22 @@ private void validateWildcardPatterns(
459459
}
460460
}
461461

462+
/** Extract field name from UnresolvedExpression. Handles Field and Alias expressions. */
463+
private String extractFieldName(UnresolvedExpression expr) {
464+
if (expr instanceof org.opensearch.sql.ast.expression.Field) {
465+
return ((org.opensearch.sql.ast.expression.Field) expr).getField().toString();
466+
} else if (expr instanceof org.opensearch.sql.ast.expression.Alias) {
467+
org.opensearch.sql.ast.expression.Alias alias =
468+
(org.opensearch.sql.ast.expression.Alias) expr;
469+
if (alias.getDelegated() instanceof org.opensearch.sql.ast.expression.Field) {
470+
return ((org.opensearch.sql.ast.expression.Field) alias.getDelegated())
471+
.getField()
472+
.toString();
473+
}
474+
}
475+
return null;
476+
}
477+
462478
private boolean isMetadataField(String fieldName) {
463479
return OpenSearchConstants.METADATAFIELD_TYPE_MAP.containsKey(fieldName);
464480
}
@@ -668,7 +684,26 @@ public RelNode visitBin(Bin node, CalcitePlanContext context) {
668684
RexNode binExpression = BinUtils.createBinExpression(node, fieldExpr, context, rexVisitor);
669685

670686
String alias = node.getAlias() != null ? node.getAlias() : fieldName;
671-
projectPlusOverriding(List.of(binExpression), List.of(alias), context);
687+
688+
// Check if this field is used in aggregation grouping with multiple fields
689+
if (context.getAggregationGroupByFields().contains(fieldName)
690+
&& context.getAggregationGroupByCount() > 1) {
691+
// For multi-field aggregation: preserve BOTH original field and binned field
692+
// The binned field (fieldName_bin) is used for grouping
693+
// The original field is used for MIN aggregation to show actual timestamps per group
694+
List<RexNode> allFields = new ArrayList<>(context.relBuilder.fields());
695+
List<String> allFieldNames =
696+
new ArrayList<>(context.relBuilder.peek().getRowType().getFieldNames());
697+
698+
// Add the binned field with _bin suffix
699+
allFields.add(binExpression);
700+
allFieldNames.add(fieldName + "_bin");
701+
702+
context.relBuilder.project(allFields, allFieldNames);
703+
} else {
704+
// For non-aggregation queries OR single-field binning: replace field with binned value
705+
projectPlusOverriding(List.of(binExpression), List.of(alias), context);
706+
}
672707

673708
return context.relBuilder.peek();
674709
}
@@ -1081,20 +1116,98 @@ private Pair<List<RexNode>, List<AggCall>> resolveAttributesForAggregation(
10811116
CalcitePlanContext context) {
10821117
List<AggCall> aggCallList =
10831118
aggExprList.stream().map(expr -> aggVisitor.analyze(expr, context)).toList();
1084-
List<RexNode> groupByList =
1085-
groupExprList.stream().map(expr -> rexVisitor.analyze(expr, context)).toList();
1086-
return Pair.of(groupByList, aggCallList);
1119+
1120+
// Get available field names in the current relation
1121+
List<String> availableFields = context.relBuilder.peek().getRowType().getFieldNames();
1122+
1123+
// Build group-by list, replacing fields with their _bin columns if they exist
1124+
List<RexNode> groupByList = new ArrayList<>();
1125+
List<AggCall> additionalAggCalls = new ArrayList<>();
1126+
1127+
// Track if we have bin columns - we'll need this to decide whether to add MIN aggregations
1128+
boolean hasBinColumns = false;
1129+
int nonBinGroupByCount = 0;
1130+
1131+
for (UnresolvedExpression groupExpr : groupExprList) {
1132+
RexNode resolvedExpr = rexVisitor.analyze(groupExpr, context);
1133+
1134+
// Extract field name from UnresolvedExpression
1135+
String fieldName = extractFieldName(groupExpr);
1136+
1137+
// Check if this field has a corresponding _bin column
1138+
if (fieldName != null) {
1139+
String binColumnName = fieldName + "_bin";
1140+
if (availableFields.contains(binColumnName)) {
1141+
// Use the _bin column for grouping
1142+
groupByList.add(context.relBuilder.field(binColumnName));
1143+
hasBinColumns = true;
1144+
continue;
1145+
}
1146+
}
1147+
1148+
// Regular group-by field
1149+
groupByList.add(resolvedExpr);
1150+
nonBinGroupByCount++;
1151+
}
1152+
1153+
// Only add MIN aggregations for bin columns if there are OTHER group-by fields
1154+
// This matches OpenSearch behavior:
1155+
// - With multi-field grouping (e.g., by region, timestamp): Show MIN(timestamp) per group
1156+
// - With single-field grouping (e.g., by timestamp only): Show bin start time
1157+
if (hasBinColumns && nonBinGroupByCount > 0) {
1158+
for (UnresolvedExpression groupExpr : groupExprList) {
1159+
String fieldName = extractFieldName(groupExpr);
1160+
if (fieldName != null) {
1161+
String binColumnName = fieldName + "_bin";
1162+
if (availableFields.contains(binColumnName)) {
1163+
// Add MIN(original_field) to show minimum timestamp per bin
1164+
additionalAggCalls.add(
1165+
context.relBuilder.min(context.relBuilder.field(fieldName)).as(fieldName));
1166+
}
1167+
}
1168+
}
1169+
}
1170+
1171+
// Combine original aggregations with additional MIN aggregations for binned fields
1172+
List<AggCall> combinedAggCalls = new ArrayList<>(aggCallList);
1173+
combinedAggCalls.addAll(additionalAggCalls);
1174+
1175+
return Pair.of(groupByList, combinedAggCalls);
10871176
}
10881177

10891178
@Override
10901179
public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
1091-
visitChildren(node, context);
1092-
1180+
// Prepare partition columns for bin operations before visiting children
1181+
// This allows WIDTH_BUCKET to use per-group min/max (matching auto_date_histogram)
10931182
List<UnresolvedExpression> aggExprList = node.getAggExprList();
10941183
List<UnresolvedExpression> groupExprList = new ArrayList<>();
1184+
UnresolvedExpression span = node.getSpan();
1185+
if (Objects.nonNull(span)) {
1186+
groupExprList.add(span);
1187+
}
1188+
groupExprList.addAll(node.getGroupExprList());
1189+
1190+
// Store group-by field names and count so bin operations can preserve original fields
1191+
java.util.Set<String> savedGroupByFields = context.getAggregationGroupByFields();
1192+
int savedGroupByCount = context.getAggregationGroupByCount();
1193+
context.setAggregationGroupByFields(new java.util.HashSet<>());
1194+
context.setAggregationGroupByCount(groupExprList.size());
1195+
for (UnresolvedExpression groupExpr : groupExprList) {
1196+
String fieldName = extractFieldName(groupExpr);
1197+
if (fieldName != null) {
1198+
context.getAggregationGroupByFields().add(fieldName);
1199+
}
1200+
}
1201+
1202+
visitChildren(node, context);
1203+
1204+
// Restore previous group-by fields and count
1205+
context.setAggregationGroupByFields(savedGroupByFields);
1206+
context.setAggregationGroupByCount(savedGroupByCount);
1207+
1208+
groupExprList.clear();
10951209
// The span column is always the first column in result whatever
10961210
// the order of span in query is first or last one
1097-
UnresolvedExpression span = node.getSpan();
10981211
if (Objects.nonNull(span)) {
10991212
groupExprList.add(span);
11001213
List<RexNode> timeSpanFilters =
@@ -1139,23 +1252,86 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
11391252
// the sequence of output schema is "count, colA, colB".
11401253
List<RexNode> outputFields = context.relBuilder.fields();
11411254
int numOfOutputFields = outputFields.size();
1142-
int numOfAggList = aggExprList.size();
1255+
int numOfUserAggregations = aggExprList.size();
11431256
List<RexNode> reordered = new ArrayList<>(numOfOutputFields);
1144-
// Add aggregation results first
1145-
List<RexNode> aggRexList =
1146-
outputFields.subList(numOfOutputFields - numOfAggList, numOfOutputFields);
1147-
reordered.addAll(aggRexList);
1148-
// Add group by columns
1149-
List<RexNode> aliasedGroupByList =
1150-
aggregationAttributes.getLeft().stream()
1151-
.map(this::extractAliasLiteral)
1152-
.flatMap(Optional::stream)
1153-
.map(ref -> ref.getValueAs(String.class))
1154-
.map(context.relBuilder::field)
1155-
.map(f -> (RexNode) f)
1156-
.toList();
1157-
reordered.addAll(aliasedGroupByList);
1158-
context.relBuilder.project(reordered);
1257+
1258+
// Add user-specified aggregation results first (exclude MIN aggregations for binned fields)
1259+
List<RexNode> userAggRexList =
1260+
outputFields.subList(
1261+
numOfOutputFields - aggregationAttributes.getRight().size(),
1262+
numOfOutputFields - aggregationAttributes.getRight().size() + numOfUserAggregations);
1263+
1264+
// Wrap AVG aggregations with ROUND to fix floating point precision
1265+
for (int i = 0; i < userAggRexList.size(); i++) {
1266+
RexNode aggRex = userAggRexList.get(i);
1267+
UnresolvedExpression aggExpr = aggExprList.get(i);
1268+
1269+
// Unwrap Alias to get to the actual aggregation function
1270+
UnresolvedExpression actualAggExpr = aggExpr;
1271+
if (aggExpr instanceof org.opensearch.sql.ast.expression.Alias) {
1272+
actualAggExpr = ((org.opensearch.sql.ast.expression.Alias) aggExpr).getDelegated();
1273+
}
1274+
1275+
// Check if this is an AVG aggregation
1276+
if (actualAggExpr instanceof org.opensearch.sql.ast.expression.AggregateFunction) {
1277+
org.opensearch.sql.ast.expression.AggregateFunction aggFunc =
1278+
(org.opensearch.sql.ast.expression.AggregateFunction) actualAggExpr;
1279+
if ("avg".equalsIgnoreCase(aggFunc.getFuncName())) {
1280+
// Wrap with ROUND(value, 2)
1281+
aggRex =
1282+
context.relBuilder.call(
1283+
org.apache.calcite.sql.fun.SqlStdOperatorTable.ROUND,
1284+
aggRex,
1285+
context.rexBuilder.makeLiteral(
1286+
2,
1287+
context
1288+
.relBuilder
1289+
.getTypeFactory()
1290+
.createSqlType(org.apache.calcite.sql.type.SqlTypeName.INTEGER),
1291+
false));
1292+
}
1293+
}
1294+
reordered.add(aggRex);
1295+
}
1296+
1297+
// Add group by columns, replacing _bin columns with their MIN aggregations
1298+
// Get field names from the aggregate output (group-by fields come first)
1299+
List<String> allFieldNames = context.relBuilder.peek().getRowType().getFieldNames();
1300+
int numGroupByFields = aggregationAttributes.getLeft().size();
1301+
1302+
List<String> outputFieldNames = new ArrayList<>();
1303+
1304+
for (int i = 0; i < numGroupByFields; i++) {
1305+
String fieldName = allFieldNames.get(i);
1306+
if (fieldName.endsWith("_bin")) {
1307+
// This is a bin column
1308+
String originalFieldName = fieldName.substring(0, fieldName.length() - 4); // Remove "_bin"
1309+
// Check if we have a MIN aggregation for this field (only present for multi-field grouping)
1310+
if (allFieldNames.contains(originalFieldName)) {
1311+
// Use the MIN aggregation
1312+
reordered.add(context.relBuilder.field(originalFieldName));
1313+
outputFieldNames.add(originalFieldName);
1314+
} else {
1315+
// Use the bin column directly (for single-field binning) - rename to original name
1316+
reordered.add(context.relBuilder.field(fieldName));
1317+
outputFieldNames.add(originalFieldName); // Rename _bin field to original name
1318+
}
1319+
} else {
1320+
// Regular group-by field
1321+
reordered.add(context.relBuilder.field(fieldName));
1322+
outputFieldNames.add(fieldName);
1323+
}
1324+
}
1325+
1326+
// Add aggregation field names (after group-by fields in the reordered list)
1327+
// The user aggregations are at the beginning of reordered list, so we add their names
1328+
int aggStartIndex = numOfOutputFields - aggregationAttributes.getRight().size();
1329+
for (int i = aggStartIndex; i < aggStartIndex + numOfUserAggregations; i++) {
1330+
outputFieldNames.add(
1331+
0, allFieldNames.get(i)); // Add at beginning to match reordered list order
1332+
}
1333+
1334+
context.relBuilder.project(reordered, outputFieldNames);
11591335

11601336
return context.relBuilder.peek();
11611337
}

0 commit comments

Comments
 (0)