diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java index c5073b01fd9c2..ec284957c8456 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java @@ -60,6 +60,7 @@ import com.starrocks.sql.optimizer.rule.transformation.PushLimitAndFilterToCTEProduceRule; import com.starrocks.sql.optimizer.rule.transformation.RemoveAggregationFromAggTable; import com.starrocks.sql.optimizer.rule.transformation.RewriteGroupingSetsByCTERule; +import com.starrocks.sql.optimizer.rule.transformation.RewriteMultiDistinctRule; import com.starrocks.sql.optimizer.rule.transformation.RewriteSimpleAggToMetaScanRule; import com.starrocks.sql.optimizer.rule.transformation.SeparateProjectRule; import com.starrocks.sql.optimizer.rule.transformation.SkewJoinOptimizeRule; @@ -427,7 +428,6 @@ private OptExpression logicalRuleRewrite( deriveLogicalProperty(tree); } - ruleRewriteIterative(tree, rootTaskContext, RuleSetType.MULTI_DISTINCT_REWRITE); ruleRewriteIterative(tree, rootTaskContext, RuleSetType.PUSH_DOWN_PREDICATE); // No heavy metadata operation before external table partition prune @@ -436,6 +436,7 @@ private OptExpression logicalRuleRewrite( // rewrite before SplitScanORToUnionRule ruleRewriteOnlyOnce(tree, rootTaskContext, new SplitDatePredicateRule()); ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.PARTITION_PRUNE); + ruleRewriteIterative(tree, rootTaskContext, new RewriteMultiDistinctRule()); ruleRewriteIterative(tree, rootTaskContext, RuleSetType.PRUNE_EMPTY_OPERATOR); ruleRewriteIterative(tree, rootTaskContext, RuleSetType.PRUNE_PROJECT); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Utils.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Utils.java index 7b91c92e776d7..b363d42fcd2a4 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Utils.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Utils.java @@ -52,8 +52,7 @@ import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter; import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter; import com.starrocks.sql.optimizer.statistics.ColumnStatistic; -import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.collections.MapUtils; +import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.SetUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -715,17 +714,26 @@ public static boolean mustGenerateMultiStageAggregate(Operator inputOp, Operator aggs = ((PhysicalHashAggregateOperator) inputOp).getAggregations(); } - if (MapUtils.isEmpty(aggs)) { - return false; - } else { - // Must do multiple stage aggregate when aggregate distinct function has array type - // Must generate three, four phase aggregate for distinct aggregate with multi columns - return aggs.values().stream().anyMatch(callOperator -> callOperator.isDistinct() - && (callOperator.getChildren().size() > 1 || - callOperator.getChildren().stream().anyMatch(c -> c.getType().isComplexType()))); + for (CallOperator callOperator : aggs.values()) { + if (callOperator.isDistinct()) { + String fnName = callOperator.getFnName(); + List children = callOperator.getChildren(); + if (children.size() > 1 || children.stream().anyMatch(c -> c.getType().isComplexType())) { + return true; + } + if (FunctionSet.GROUP_CONCAT.equalsIgnoreCase(fnName) || FunctionSet.AVG.equalsIgnoreCase(fnName)) { + return true; + } else if (FunctionSet.ARRAY_AGG.equalsIgnoreCase(fnName)) { + if (children.size() > 1 || children.get(0).getType().isDecimalOfAnyVersion()) { + return true; + } + } + } } + return false; } + // without distinct function, the common distinctCols is an empty list. public static Optional> extractCommonDistinctCols(Collection aggCallOperators) { Set distinctChildren = Sets.newHashSet(); for (CallOperator callOperator : aggCallOperators) { diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/scalar/ScalarOperatorUtil.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/scalar/ScalarOperatorUtil.java index aeb0629fe91c4..717a0a3c201d4 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/scalar/ScalarOperatorUtil.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/scalar/ScalarOperatorUtil.java @@ -23,7 +23,6 @@ import com.starrocks.catalog.ScalarType; import com.starrocks.catalog.Type; import com.starrocks.server.GlobalStateMgr; -import com.starrocks.sql.analyzer.DecimalV3FunctionAnalyzer; import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter; import static com.starrocks.catalog.Function.CompareMode.IS_IDENTICAL; @@ -76,14 +75,4 @@ public static Function findSumFn(Type[] argTypes) { } return newFn; } - - public static CallOperator buildMultiSumDistinct(CallOperator oldFunctionCall) { - Function multiDistinctSum = DecimalV3FunctionAnalyzer.convertSumToMultiDistinctSum( - oldFunctionCall.getFunction(), oldFunctionCall.getChild(0).getType()); - ScalarOperatorRewriter scalarOpRewriter = new ScalarOperatorRewriter(); - return (CallOperator) scalarOpRewriter.rewrite( - new CallOperator( - FunctionSet.MULTI_DISTINCT_SUM, multiDistinctSum.getReturnType(), - oldFunctionCall.getChildren(), multiDistinctSum), DEFAULT_TYPE_CAST_RULE); - } } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java index fd5cf53aeb954..e981b3c6e5065 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java @@ -139,8 +139,6 @@ import com.starrocks.sql.optimizer.rule.transformation.RewriteCountIfFunction; import com.starrocks.sql.optimizer.rule.transformation.RewriteDuplicateAggregateFnRule; import com.starrocks.sql.optimizer.rule.transformation.RewriteHllCountDistinctRule; -import com.starrocks.sql.optimizer.rule.transformation.RewriteMultiDistinctByCTERule; -import com.starrocks.sql.optimizer.rule.transformation.RewriteMultiDistinctRule; import com.starrocks.sql.optimizer.rule.transformation.RewriteSimpleAggToMetaScanRule; import com.starrocks.sql.optimizer.rule.transformation.RewriteSumByAssociativeRule; import com.starrocks.sql.optimizer.rule.transformation.ScalarApply2AnalyticRule; @@ -361,11 +359,6 @@ public class RuleSet { new RewriteCountIfFunction() )); - REWRITE_RULES.put(RuleSetType.MULTI_DISTINCT_REWRITE, ImmutableList.of( - new RewriteMultiDistinctByCTERule(), - new RewriteMultiDistinctRule() - )); - REWRITE_RULES.put(RuleSetType.PRUNE_PROJECT, ImmutableList.of( new PruneProjectRule(), new PruneProjectEmptyRule(), diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteMultiDistinctByCTERule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/MultiDistinctByCTERewriter.java similarity index 92% rename from fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteMultiDistinctByCTERule.java rename to fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/MultiDistinctByCTERewriter.java index 5c90c4fc149b8..6c239013741b6 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteMultiDistinctByCTERule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/MultiDistinctByCTERewriter.java @@ -23,7 +23,6 @@ import com.starrocks.catalog.Function; import com.starrocks.catalog.FunctionSet; import com.starrocks.catalog.ScalarType; -import com.starrocks.qe.ConnectContext; import com.starrocks.server.GlobalStateMgr; import com.starrocks.sql.analyzer.DecimalV3FunctionAnalyzer; import com.starrocks.sql.optimizer.ExpressionContext; @@ -34,7 +33,6 @@ import com.starrocks.sql.optimizer.base.ColumnRefSet; import com.starrocks.sql.optimizer.operator.Operator; import com.starrocks.sql.optimizer.operator.OperatorBuilderFactory; -import com.starrocks.sql.optimizer.operator.OperatorType; import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator; import com.starrocks.sql.optimizer.operator.logical.LogicalCTEAnchorOperator; import com.starrocks.sql.optimizer.operator.logical.LogicalCTEConsumeOperator; @@ -42,7 +40,6 @@ import com.starrocks.sql.optimizer.operator.logical.LogicalFilterOperator; import com.starrocks.sql.optimizer.operator.logical.LogicalJoinOperator; import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator; -import com.starrocks.sql.optimizer.operator.pattern.Pattern; import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator; import com.starrocks.sql.optimizer.operator.scalar.CallOperator; import com.starrocks.sql.optimizer.operator.scalar.CastOperator; @@ -51,7 +48,6 @@ import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter; import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter; import com.starrocks.sql.optimizer.rewrite.scalar.ImplicitCastRule; -import com.starrocks.sql.optimizer.rule.RuleType; import java.util.LinkedList; import java.util.List; @@ -123,48 +119,11 @@ * * */ -public class RewriteMultiDistinctByCTERule extends TransformationRule { - private final ScalarOperatorRewriter scalarRewriter = new ScalarOperatorRewriter(); - - public RewriteMultiDistinctByCTERule() { - super(RuleType.TF_REWRITE_MULTI_DISTINCT_BY_CTE, - Pattern.create(OperatorType.LOGICAL_AGGR).addChildren(Pattern.create( - OperatorType.PATTERN_LEAF))); - } - - @Override - public boolean check(OptExpression input, OptimizerContext context) { - // check cte is disabled or hasNoGroup false - LogicalAggregationOperator agg = (LogicalAggregationOperator) input.getOp(); - List distinctAggOperatorList = agg.getAggregations().values().stream() - .filter(CallOperator::isDistinct).collect(Collectors.toList()); - boolean hasMultiColumns = distinctAggOperatorList.stream().anyMatch(f -> f.getDistinctChildren().size() > 1); - if (hasMultiColumns && distinctAggOperatorList.size() > 1) { - return true; - } - - if (!context.getSessionVariable().isCboCteReuse()) { - return false; - } - - if (agg.hasSkew() && distinctAggOperatorList.size() > 1 && !agg.getGroupingKeys().isEmpty()) { - return true; - } - - if (agg.hasLimit() && !ConnectContext.get().getSessionVariable().isPreferCTERewrite()) { - return false; - } - - if (!hasMultiColumns && agg.getGroupingKeys().size() > 1) { - return false; - } +public class MultiDistinctByCTERewriter { - return distinctAggOperatorList.size() > 1 || agg.getAggregations().values().stream() - .anyMatch(call -> call.isDistinct() && call.getFnName().equals(FunctionSet.AVG)); - } + private final ScalarOperatorRewriter scalarRewriter = new ScalarOperatorRewriter(); - @Override - public List transform(OptExpression input, OptimizerContext context) { + public List transformImpl(OptExpression input, OptimizerContext context) { ColumnRefFactory columnRefFactory = context.getColumnRefFactory(); // define cteId int cteId = context.getCteContext().getNextCteId(); @@ -453,4 +412,5 @@ private LogicalCTEConsumeOperator buildCteConsume(OptExpression cteProduce, Colu return new LogicalCTEConsumeOperator(cteId, consumeOutputMap); } + } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/MultiDistinctByMultiFuncRewriter.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/MultiDistinctByMultiFuncRewriter.java new file mode 100644 index 0000000000000..455d09544c593 --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/MultiDistinctByMultiFuncRewriter.java @@ -0,0 +1,194 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.optimizer.rule.transformation; + +import com.google.common.collect.Lists; +import com.starrocks.analysis.FunctionName; +import com.starrocks.catalog.Function; +import com.starrocks.catalog.FunctionSet; +import com.starrocks.catalog.ScalarType; +import com.starrocks.catalog.Type; +import com.starrocks.server.GlobalStateMgr; +import com.starrocks.sql.analyzer.DecimalV3FunctionAnalyzer; +import com.starrocks.sql.optimizer.OptExpression; +import com.starrocks.sql.optimizer.OptimizerContext; +import com.starrocks.sql.optimizer.operator.AggType; +import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalFilterOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator; +import com.starrocks.sql.optimizer.operator.scalar.CallOperator; +import com.starrocks.sql.optimizer.operator.scalar.CastOperator; +import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; +import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; +import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter; +import com.starrocks.sql.optimizer.rewrite.scalar.ImplicitCastRule; +import com.starrocks.sql.optimizer.rewrite.scalar.ScalarOperatorRewriteRule; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static com.starrocks.catalog.Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF; + +public class MultiDistinctByMultiFuncRewriter { + + private static final List DEFAULT_TYPE_CAST_RULE = Lists.newArrayList( + new ImplicitCastRule() + ); + private final ScalarOperatorRewriter scalarRewriter = new ScalarOperatorRewriter(); + + public List transformImpl(OptExpression input, OptimizerContext context) { + LogicalAggregationOperator aggregationOperator = (LogicalAggregationOperator) input.getOp(); + + Map newAggMap = new HashMap<>(); + for (Map.Entry aggregation : aggregationOperator.getAggregations() + .entrySet()) { + CallOperator oldFunctionCall = aggregation.getValue(); + if (oldFunctionCall.isDistinct()) { + CallOperator newAggOperator; + + if (oldFunctionCall.getFnName().equalsIgnoreCase(FunctionSet.COUNT)) { + newAggOperator = buildMultiCountDistinct(oldFunctionCall); + } else if (oldFunctionCall.getFnName().equalsIgnoreCase(FunctionSet.SUM)) { + newAggOperator = buildMultiSumDistinct(oldFunctionCall); + } else if (oldFunctionCall.getFnName().equals(FunctionSet.ARRAY_AGG)) { + if (oldFunctionCall.getColumnRefs().size() == 1 && + !oldFunctionCall.getColumnRefs().get(0).getType().isDecimalOfAnyVersion()) { + newAggOperator = buildArrayAggDistinct(oldFunctionCall); + } else { + newAggOperator = oldFunctionCall; + } + } else { + newAggOperator = oldFunctionCall; + } + newAggMap.put(aggregation.getKey(), newAggOperator); + } else { + newAggMap.put(aggregation.getKey(), aggregation.getValue()); + } + } + + /* + * Repeat the loop once, because avg can use the newly generated aggregate function last time, + * so that the expression can be reused. such as: count(distinct v1), avg(distinct v1), sum(distinct v1), + * avg can use multi_distinct_x generated by count or sum + */ + boolean hasAvg = false; + Map projections = new HashMap<>(); + Map newAggMapWithAvg = new HashMap<>(); + for (Map.Entry aggMap : newAggMap.entrySet()) { + CallOperator oldFunctionCall = aggMap.getValue(); + if (oldFunctionCall.isDistinct() && oldFunctionCall.getFnName().equals(FunctionSet.AVG)) { + hasAvg = true; + CallOperator count = buildMultiCountDistinct(oldFunctionCall); + ColumnRefOperator countColRef = null; + for (Map.Entry entry : newAggMap.entrySet()) { + if (entry.getValue().equals(count)) { + countColRef = entry.getKey(); + break; + } + } + countColRef = countColRef == null ? + context.getColumnRefFactory().create(count, count.getType(), count.isNullable()) : countColRef; + newAggMapWithAvg.put(countColRef, count); + + CallOperator sum = buildMultiSumDistinct(oldFunctionCall); + ColumnRefOperator sumColRef = null; + for (Map.Entry entry : newAggMap.entrySet()) { + if (entry.getValue().equals(sum)) { + sumColRef = entry.getKey(); + break; + } + } + sumColRef = sumColRef == null ? + context.getColumnRefFactory().create(sum, sum.getType(), sum.isNullable()) : sumColRef; + newAggMapWithAvg.put(sumColRef, sum); + CallOperator multiAvg = new CallOperator(FunctionSet.DIVIDE, oldFunctionCall.getType(), + Lists.newArrayList(sumColRef, countColRef)); + if (multiAvg.getType().isDecimalV3()) { + // There is not need to apply ImplicitCastRule to divide operator of decimal types. + // but we should cast BIGINT-typed countColRef into DECIMAL(38,0). + ScalarType decimal128p38s0 = ScalarType.createDecimalV3NarrowestType(38, 0); + multiAvg.getChildren().set( + 1, new CastOperator(decimal128p38s0, multiAvg.getChild(1), true)); + } else { + multiAvg = (CallOperator) scalarRewriter.rewrite(multiAvg, + Lists.newArrayList(new ImplicitCastRule())); + } + projections.put(aggMap.getKey(), multiAvg); + } else { + projections.put(aggMap.getKey(), aggMap.getKey()); + newAggMapWithAvg.put(aggMap.getKey(), aggMap.getValue()); + } + } + + OptExpression result; + if (hasAvg) { + OptExpression aggOpt = OptExpression + .create(new LogicalAggregationOperator.Builder().withOperator(aggregationOperator) + .setType(AggType.GLOBAL) + .setAggregations(newAggMapWithAvg) + .setPredicate(null) + .build(), + input.getInputs()); + aggregationOperator.getGroupingKeys().forEach(c -> projections.put(c, c)); + result = OptExpression.create(new LogicalProjectOperator(projections), Lists.newArrayList(aggOpt)); + } else { + result = OptExpression + .create(new LogicalAggregationOperator.Builder().withOperator(aggregationOperator) + .setType(AggType.GLOBAL) + .setAggregations(newAggMap) + .setPredicate(null) + .build(), + input.getInputs()); + } + + if (aggregationOperator.getPredicate() != null) { + result = OptExpression.create(new LogicalFilterOperator(aggregationOperator.getPredicate()), result); + } + + return Lists.newArrayList(result); + } + + private CallOperator buildMultiCountDistinct(CallOperator oldFunctionCall) { + Function searchDesc = new Function(new FunctionName(FunctionSet.MULTI_DISTINCT_COUNT), + oldFunctionCall.getFunction().getArgs(), Type.INVALID, false); + Function fn = GlobalStateMgr.getCurrentState().getFunction(searchDesc, IS_NONSTRICT_SUPERTYPE_OF); + + return (CallOperator) scalarRewriter.rewrite( + new CallOperator(FunctionSet.MULTI_DISTINCT_COUNT, fn.getReturnType(), oldFunctionCall.getChildren(), + fn), + DEFAULT_TYPE_CAST_RULE); + } + + private CallOperator buildArrayAggDistinct(CallOperator oldFunctionCall) { + Function searchDesc = new Function(new FunctionName(FunctionSet.ARRAY_AGG_DISTINCT), + oldFunctionCall.getFunction().getArgs(), Type.INVALID, false); + Function fn = GlobalStateMgr.getCurrentState().getFunction(searchDesc, IS_NONSTRICT_SUPERTYPE_OF); + + return (CallOperator) scalarRewriter.rewrite( + new CallOperator(FunctionSet.ARRAY_AGG_DISTINCT, fn.getReturnType(), oldFunctionCall.getChildren(), + fn), + DEFAULT_TYPE_CAST_RULE); + } + + private CallOperator buildMultiSumDistinct(CallOperator oldFunctionCall) { + Function multiDistinctSum = DecimalV3FunctionAnalyzer.convertSumToMultiDistinctSum( + oldFunctionCall.getFunction(), oldFunctionCall.getChild(0).getType()); + return (CallOperator) scalarRewriter.rewrite( + new CallOperator( + FunctionSet.MULTI_DISTINCT_SUM, multiDistinctSum.getReturnType(), + oldFunctionCall.getChildren(), multiDistinctSum), DEFAULT_TYPE_CAST_RULE); + } +} diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteMultiDistinctRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteMultiDistinctRule.java index e4069f2211433..26cb23b54b4e3 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteMultiDistinctRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteMultiDistinctRule.java @@ -15,42 +15,35 @@ package com.starrocks.sql.optimizer.rule.transformation; import com.google.common.collect.Lists; -import com.starrocks.analysis.FunctionName; -import com.starrocks.catalog.Function; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; import com.starrocks.catalog.FunctionSet; -import com.starrocks.catalog.ScalarType; import com.starrocks.catalog.Type; -import com.starrocks.server.GlobalStateMgr; -import com.starrocks.sql.analyzer.DecimalV3FunctionAnalyzer; +import com.starrocks.sql.common.ErrorType; +import com.starrocks.sql.common.StarRocksPlannerException; +import com.starrocks.sql.optimizer.ExpressionContext; import com.starrocks.sql.optimizer.OptExpression; import com.starrocks.sql.optimizer.OptimizerContext; -import com.starrocks.sql.optimizer.operator.AggType; +import com.starrocks.sql.optimizer.Utils; import com.starrocks.sql.optimizer.operator.OperatorType; import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator; -import com.starrocks.sql.optimizer.operator.logical.LogicalFilterOperator; -import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator; import com.starrocks.sql.optimizer.operator.pattern.Pattern; import com.starrocks.sql.optimizer.operator.scalar.CallOperator; -import com.starrocks.sql.optimizer.operator.scalar.CastOperator; import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; -import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter; -import com.starrocks.sql.optimizer.rewrite.scalar.ImplicitCastRule; -import com.starrocks.sql.optimizer.rewrite.scalar.ScalarOperatorRewriteRule; import com.starrocks.sql.optimizer.rule.RuleType; +import com.starrocks.sql.optimizer.statistics.Statistics; +import com.starrocks.sql.optimizer.statistics.StatisticsCalculator; -import java.util.HashMap; import java.util.List; -import java.util.Map; +import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; -import static com.starrocks.catalog.Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF; +import static com.starrocks.sql.optimizer.statistics.StatisticsEstimateCoefficient.LOW_AGGREGATE_EFFECT_COEFFICIENT; +import static com.starrocks.sql.optimizer.statistics.StatisticsEstimateCoefficient.MEDIUM_AGGREGATE_EFFECT_COEFFICIENT; public class RewriteMultiDistinctRule extends TransformationRule { - private static final List DEFAULT_TYPE_CAST_RULE = Lists.newArrayList( - new ImplicitCastRule() - ); - private final ScalarOperatorRewriter scalarRewriter = new ScalarOperatorRewriter(); public RewriteMultiDistinctRule() { super(RuleType.TF_REWRITE_MULTI_DISTINCT, @@ -62,149 +55,131 @@ public RewriteMultiDistinctRule() { public boolean check(OptExpression input, OptimizerContext context) { LogicalAggregationOperator agg = (LogicalAggregationOperator) input.getOp(); - List distinctAggOperatorList = agg.getAggregations().values().stream() - .filter(CallOperator::isDistinct).collect(Collectors.toList()); + Optional> distinctCols = Utils.extractCommonDistinctCols(agg.getAggregations().values()); - boolean hasMultiColumns = distinctAggOperatorList.stream().anyMatch(f -> f.getDistinctChildren().size() > 1); - return (distinctAggOperatorList.size() > 1 || agg.getAggregations().values().stream() - .anyMatch(call -> call.isDistinct() && call.getFnName().equals(FunctionSet.AVG))) && !hasMultiColumns; + // all distinct function use the same distinct columns, we use the split rule to rewrite + return !distinctCols.isPresent(); } - @Override public List transform(OptExpression input, OptimizerContext context) { - LogicalAggregationOperator aggregationOperator = (LogicalAggregationOperator) input.getOp(); - - Map newAggMap = new HashMap<>(); - for (Map.Entry aggregation : aggregationOperator.getAggregations() - .entrySet()) { - CallOperator oldFunctionCall = aggregation.getValue(); - if (oldFunctionCall.isDistinct()) { - CallOperator newAggOperator; - if (oldFunctionCall.getFnName().equalsIgnoreCase(FunctionSet.COUNT)) { - newAggOperator = buildMultiCountDistinct(oldFunctionCall); - } else if (oldFunctionCall.getFnName().equalsIgnoreCase(FunctionSet.SUM)) { - newAggOperator = buildMultiSumDistinct(oldFunctionCall); - } else if (oldFunctionCall.getFnName().equals(FunctionSet.ARRAY_AGG)) { - newAggOperator = buildArrayAggDistinct(oldFunctionCall); - } else if (oldFunctionCall.getFnName().equalsIgnoreCase(FunctionSet.AVG)) { - newAggOperator = oldFunctionCall; - } else { - return Lists.newArrayList(); - } - newAggMap.put(aggregation.getKey(), newAggOperator); - } else { - newAggMap.put(aggregation.getKey(), aggregation.getValue()); - } + if (useCteToRewrite(input, context)) { + MultiDistinctByCTERewriter rewriter = new MultiDistinctByCTERewriter(); + return rewriter.transformImpl(input, context); + } else { + MultiDistinctByMultiFuncRewriter rewriter = new MultiDistinctByMultiFuncRewriter(); + return rewriter.transformImpl(input, context); } + } - /* - * Repeat the loop once, because avg can use the newly generated aggregate function last time, - * so that the expression can be reused. such as: count(distinct v1), avg(distinct v1), sum(distinct v1), - * avg can use multi_distinct_x generated by count or sum - */ - boolean hasAvg = false; - Map projections = new HashMap<>(); - Map newAggMapWithAvg = new HashMap<>(); - for (Map.Entry aggMap : newAggMap.entrySet()) { - CallOperator oldFunctionCall = aggMap.getValue(); - if (oldFunctionCall.isDistinct() && oldFunctionCall.getFnName().equals(FunctionSet.AVG)) { - hasAvg = true; - CallOperator count = buildMultiCountDistinct(oldFunctionCall); - ColumnRefOperator countColRef = null; - for (Map.Entry entry : newAggMap.entrySet()) { - if (entry.getValue().equals(count)) { - countColRef = entry.getKey(); - break; - } - } - countColRef = countColRef == null ? - context.getColumnRefFactory().create(count, count.getType(), count.isNullable()) : countColRef; - newAggMapWithAvg.put(countColRef, count); - - CallOperator sum = buildMultiSumDistinct(oldFunctionCall); - ColumnRefOperator sumColRef = null; - for (Map.Entry entry : newAggMap.entrySet()) { - if (entry.getValue().equals(sum)) { - sumColRef = entry.getKey(); - break; - } - } - sumColRef = sumColRef == null ? - context.getColumnRefFactory().create(sum, sum.getType(), sum.isNullable()) : sumColRef; - newAggMapWithAvg.put(sumColRef, sum); - CallOperator multiAvg = new CallOperator(FunctionSet.DIVIDE, oldFunctionCall.getType(), - Lists.newArrayList(sumColRef, countColRef)); - if (multiAvg.getType().isDecimalV3()) { - // There is not need to apply ImplicitCastRule to divide operator of decimal types. - // but we should cast BIGINT-typed countColRef into DECIMAL(38,0). - ScalarType decimal128p38s0 = ScalarType.createDecimalV3NarrowestType(38, 0); - multiAvg.getChildren().set( - 1, new CastOperator(decimal128p38s0, multiAvg.getChild(1), true)); - } else { - multiAvg = (CallOperator) scalarRewriter.rewrite(multiAvg, - Lists.newArrayList(new ImplicitCastRule())); - } - projections.put(aggMap.getKey(), multiAvg); + private boolean useCteToRewrite(OptExpression input, OptimizerContext context) { + LogicalAggregationOperator agg = (LogicalAggregationOperator) input.getOp(); + List distinctAggOperatorList = agg.getAggregations().values().stream() + .filter(CallOperator::isDistinct).collect(Collectors.toList()); + boolean hasMultiColumns = distinctAggOperatorList.stream().anyMatch(f -> f.getColumnRefs().size() > 1); + // exist multiple distinct columns should enable cte use + if (hasMultiColumns) { + if (!context.getSessionVariable().isCboCteReuse()) { + throw new StarRocksPlannerException(ErrorType.USER_ERROR, + "%s is unsupported when cbo_cte_reuse is disabled", distinctAggOperatorList); } else { - projections.put(aggMap.getKey(), aggMap.getKey()); - newAggMapWithAvg.put(aggMap.getKey(), aggMap.getValue()); + return true; } } - OptExpression result; - if (hasAvg) { - OptExpression aggOpt = OptExpression - .create(new LogicalAggregationOperator.Builder().withOperator(aggregationOperator) - .setType(AggType.GLOBAL) - .setAggregations(newAggMapWithAvg) - .build(), - input.getInputs()); - aggregationOperator.getGroupingKeys().forEach(c -> projections.put(c, c)); - result = OptExpression.create(new LogicalProjectOperator(projections), Lists.newArrayList(aggOpt)); - } else { - result = OptExpression - .create(new LogicalAggregationOperator.Builder().withOperator(aggregationOperator) - .setType(AggType.GLOBAL) - .setAggregations(newAggMap) - .build(), - input.getInputs()); + // respect prefer cte rewrite hint + if (context.getSessionVariable().isCboCteReuse() && context.getSessionVariable().isPreferCTERewrite()) { + return true; } - if (aggregationOperator.getPredicate() != null) { - result = OptExpression.create(new LogicalFilterOperator(aggregationOperator.getPredicate()), result); + // respect skew int + if (context.getSessionVariable().isCboCteReuse() && agg.hasSkew() && !agg.getGroupingKeys().isEmpty()) { + return true; } - return Lists.newArrayList(result); - } + if (context.getSessionVariable().isCboCteReuse() && + isCTEMoreEfficient(input, context, distinctAggOperatorList)) { + return true; + } - private CallOperator buildMultiCountDistinct(CallOperator oldFunctionCall) { - Function searchDesc = new Function(new FunctionName(FunctionSet.MULTI_DISTINCT_COUNT), - oldFunctionCall.getFunction().getArgs(), Type.INVALID, false); - Function fn = GlobalStateMgr.getCurrentState().getFunction(searchDesc, IS_NONSTRICT_SUPERTYPE_OF); + // all distinct one column function can be rewritten by multi distinct function + boolean canRewriteByMultiFunc = true; + for (CallOperator distinctCall : distinctAggOperatorList) { + String fnName = distinctCall.getFnName(); + List children = distinctCall.getChildren(); + Type type = children.get(0).getType(); + if (type.isComplexType() + || type.isJsonType() + || FunctionSet.GROUP_CONCAT.equalsIgnoreCase(fnName) + || (FunctionSet.ARRAY_AGG.equalsIgnoreCase(fnName) && type.isDecimalOfAnyVersion())) { + canRewriteByMultiFunc = false; + break; + } + } - return (CallOperator) scalarRewriter.rewrite( - new CallOperator(FunctionSet.MULTI_DISTINCT_COUNT, fn.getReturnType(), oldFunctionCall.getChildren(), - fn), - DEFAULT_TYPE_CAST_RULE); + if (!context.getSessionVariable().isCboCteReuse() && !canRewriteByMultiFunc) { + throw new StarRocksPlannerException(ErrorType.USER_ERROR, + "%s is unsupported when cbo_cte_reuse is disabled", distinctAggOperatorList); + } + + return !canRewriteByMultiFunc; } - private CallOperator buildArrayAggDistinct(CallOperator oldFunctionCall) { - Function searchDesc = new Function(new FunctionName(FunctionSet.ARRAY_AGG_DISTINCT), - oldFunctionCall.getFunction().getArgs(), Type.INVALID, false); - Function fn = GlobalStateMgr.getCurrentState().getFunction(searchDesc, IS_NONSTRICT_SUPERTYPE_OF); + private boolean isCTEMoreEfficient(OptExpression input, OptimizerContext context, + List distinctAggOperatorList) { + LogicalAggregationOperator aggOp = input.getOp().cast(); + if (aggOp.hasLimit()) { + return false; + } + calculateStatistics(input, context); + + Statistics inputStatistics = input.inputAt(0).getStatistics(); + List neededCols = Lists.newArrayList(aggOp.getGroupingKeys()); + distinctAggOperatorList.stream().forEach(e -> neededCols.addAll(e.getColumnRefs())); + + // no statistics available, use cte for no group by or group by only one col scenes to avoid bad case of multiple_func + if (neededCols.stream().anyMatch(e -> inputStatistics.getColumnStatistics().get(e).isUnknown())) { + return aggOp.getGroupingKeys().size() < 2; + } + + double inputRowCount = inputStatistics.getOutputRowCount(); + List deduplicateOutputRows = Lists.newArrayList(); + List distinctValueCounts = Lists.newArrayList(); + for (CallOperator callOperator : distinctAggOperatorList) { + List distinctColumns = callOperator.getColumnRefs(); + if (distinctColumns.isEmpty()) { + continue; + } + Set deduplicateKeys = Sets.newHashSet(); + deduplicateKeys.addAll(aggOp.getGroupingKeys()); + deduplicateKeys.addAll(distinctColumns); + deduplicateOutputRows.add(StatisticsCalculator.computeGroupByStatistics(Lists.newArrayList(deduplicateKeys), + inputStatistics, Maps.newHashMap())); + distinctValueCounts.add(inputStatistics.getColumnStatistics().get(distinctColumns.get(0)).getDistinctValuesCount()); + } - return (CallOperator) scalarRewriter.rewrite( - new CallOperator(FunctionSet.ARRAY_AGG_DISTINCT, fn.getReturnType(), oldFunctionCall.getChildren(), - fn), - DEFAULT_TYPE_CAST_RULE); + if (distinctValueCounts.stream().allMatch(d -> d < MEDIUM_AGGREGATE_EFFECT_COEFFICIENT)) { + // distinct key with an extreme low cardinality use multi_distinct_func maybe more efficient + return false; + } else if (deduplicateOutputRows.stream().allMatch(row -> row * LOW_AGGREGATE_EFFECT_COEFFICIENT < inputRowCount)) { + return false; + } + return true; } - private CallOperator buildMultiSumDistinct(CallOperator oldFunctionCall) { - Function multiDistinctSum = DecimalV3FunctionAnalyzer.convertSumToMultiDistinctSum( - oldFunctionCall.getFunction(), oldFunctionCall.getChild(0).getType()); - return (CallOperator) scalarRewriter.rewrite( - new CallOperator( - FunctionSet.MULTI_DISTINCT_SUM, multiDistinctSum.getReturnType(), - oldFunctionCall.getChildren(), multiDistinctSum), DEFAULT_TYPE_CAST_RULE); + private void calculateStatistics(OptExpression expr, OptimizerContext context) { + // Avoid repeated calculate + if (expr.getStatistics() != null) { + return; + } + + for (OptExpression child : expr.getInputs()) { + calculateStatistics(child, context); + } + + ExpressionContext expressionContext = new ExpressionContext(expr); + StatisticsCalculator statisticsCalculator = new StatisticsCalculator( + expressionContext, context.getColumnRefFactory(), context); + statisticsCalculator.estimatorStats(); + expr.setStatistics(expressionContext.getStatistics()); } } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/SplitTwoPhaseAggRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/SplitTwoPhaseAggRule.java index 3d13ff3ec8973..1026be51bbb0b 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/SplitTwoPhaseAggRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/SplitTwoPhaseAggRule.java @@ -33,6 +33,7 @@ import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator; import com.starrocks.sql.optimizer.operator.scalar.CallOperator; import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; +import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; import com.starrocks.sql.optimizer.rule.RuleType; import com.starrocks.sql.optimizer.statistics.ColumnStatistic; import com.starrocks.sql.optimizer.statistics.Statistics; @@ -126,10 +127,10 @@ public List transform(OptExpression input, OptimizerContext conte private boolean isSuitableForTwoStageDistinct(OptExpression input, LogicalAggregationOperator operator, List distinctColumns) { int aggMode = ConnectContext.get().getSessionVariable().getNewPlannerAggStage(); - boolean canTwoStage = canGenerateTwoStageAggregate(operator, distinctColumns); - - if (!canTwoStage) { - return false; + for (CallOperator callOperator : operator.getAggregations().values()) { + if (callOperator.isDistinct() && !canGenerateTwoStageAggregate(callOperator)) { + return false; + } } if (aggMode == TWO_STAGE.ordinal()) { @@ -141,26 +142,25 @@ private boolean isSuitableForTwoStageDistinct(OptExpression input, LogicalAggreg && isTwoStageMoreEfficient(input, distinctColumns); } - private boolean canGenerateTwoStageAggregate(LogicalAggregationOperator operator, - List distinctColumns) { - + private boolean canGenerateTwoStageAggregate(CallOperator distinctCall) { + List distinctCols = distinctCall.getColumnRefs(); + List children = distinctCall.getChildren(); // 1. multiple cols distinct is not support two stage aggregate // 2. array type col is not support two stage aggregate - if (distinctColumns.size() > 1 || distinctColumns.get(0).getType().isArrayType()) { + if (distinctCols.size() > 1 || children.get(0).getType().isComplexType()) { return false; } - // 3. group_concat distinct with columnRef is not support two stage aggregate - // 4. array_agg with order by clause is not support two stage aggregate - for (CallOperator aggCall : operator.getAggregations().values()) { - String fnName = aggCall.getFnName(); - if (FunctionSet.GROUP_CONCAT.equalsIgnoreCase(fnName)) { + // 3. group_concat distinct or avg distinct is not support two stage aggregate + // 4. array_agg with order by clause or decimal distinct col is not support two stage aggregate + String fnName = distinctCall.getFnName(); + if (FunctionSet.GROUP_CONCAT.equalsIgnoreCase(fnName) || FunctionSet.AVG.equalsIgnoreCase(fnName)) { + return false; + } else if (FunctionSet.ARRAY_AGG.equalsIgnoreCase(fnName)) { + AggregateFunction aggregateFunction = (AggregateFunction) distinctCall.getFunction(); + if (CollectionUtils.isNotEmpty(aggregateFunction.getIsAscOrder()) || + children.get(0).getType().isDecimalOfAnyVersion()) { return false; - } else if (FunctionSet.ARRAY_AGG.equalsIgnoreCase(fnName)) { - AggregateFunction aggregateFunction = (AggregateFunction) aggCall.getFunction(); - if (CollectionUtils.isNotEmpty(aggregateFunction.getIsAscOrder())) { - return false; - } } } return true; @@ -199,12 +199,20 @@ private CallOperator rewriteDistinctAggFn(CallOperator fnCall) { return new CallOperator( FunctionSet.MULTI_DISTINCT_SUM, fnCall.getType(), fnCall.getChildren(), multiDistinctSumFn, false); } else if (functionName.equalsIgnoreCase(FunctionSet.ARRAY_AGG)) { - return new CallOperator(FunctionSet.ARRAY_AGG_DISTINCT, fnCall.getType(), fnCall.getChildren(), - Expr.getBuiltinFunction(FunctionSet.ARRAY_AGG_DISTINCT, new Type[] {fnCall.getChild(0).getType()}, - IS_NONSTRICT_SUPERTYPE_OF), false); + if (fnCall.getUsedColumns().isEmpty() && fnCall.getChild(0).getType().isDecimalOfAnyVersion()) { + return fnCall; + } else { + return new CallOperator(FunctionSet.ARRAY_AGG_DISTINCT, fnCall.getType(), fnCall.getChildren(), + Expr.getBuiltinFunction(FunctionSet.ARRAY_AGG_DISTINCT, new Type[] {fnCall.getChild(0).getType()}, + IS_NONSTRICT_SUPERTYPE_OF), false); + } + } else if (functionName.equals(FunctionSet.GROUP_CONCAT)) { - // all children of group_concat is constant + // all children of group_concat are constant return fnCall; + } else if (functionName.equals(FunctionSet.AVG)) { + // all children of avg are constant + return new CallOperator(FunctionSet.AVG, fnCall.getType(), fnCall.getChildren(), fnCall.getFunction(), false); } throw new StarRocksPlannerException(ErrorType.INTERNAL_ERROR, "unsupported distinct agg functions: %s in two phase agg", fnCall); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsEstimateCoefficient.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsEstimateCoefficient.java index 08c07151a0cfa..4223cfce600f6 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsEstimateCoefficient.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsEstimateCoefficient.java @@ -46,6 +46,8 @@ public class StatisticsEstimateCoefficient { // the aggregate has good effect. public static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000; public static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100; + + public static final double EXTREME_HIGH_AGGREGATE_EFFECT_COEFFICIENT = 3; // default selectivity for anti join public static final double DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT = 0.4; // default shuffle column row count limit diff --git a/fe/fe-core/src/test/java/com/starrocks/analysis/SelectStmtTest.java b/fe/fe-core/src/test/java/com/starrocks/analysis/SelectStmtTest.java index 52a99198e7f02..136f998a559bf 100644 --- a/fe/fe-core/src/test/java/com/starrocks/analysis/SelectStmtTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/analysis/SelectStmtTest.java @@ -418,16 +418,14 @@ void testScalarCorrelatedSubquery() throws Exception { void testMultiDistinctMultiColumnWithLimit(String sql, String pattern) throws Exception { starRocksAssert.getCtx().getSessionVariable().setOptimizerExecuteTimeout(30000000); String plan = UtFrameUtils.getFragmentPlan(starRocksAssert.getCtx(), sql); - System.out.println(plan); Assert.assertTrue(plan, plan.contains(pattern)); } @Test - public void test() throws Exception { + public void testSingleMultiColumnDistinct() throws Exception { starRocksAssert.getCtx().getSessionVariable().setOptimizerExecuteTimeout(30000000); String plan = UtFrameUtils.getFragmentPlan(starRocksAssert.getCtx(), "select count(distinct k1, k2), count(distinct k3) from db1.tbl1 limit 1"); - System.out.println(plan); Assert.assertTrue(plan, plan.contains("18:NESTLOOP JOIN\n" + " | join op: CROSS JOIN\n" + " | colocate: false, reason: \n" + diff --git a/fe/fe-core/src/test/java/com/starrocks/analysis/SelectStmtWithDecimalTypesNewPlannerTest.java b/fe/fe-core/src/test/java/com/starrocks/analysis/SelectStmtWithDecimalTypesNewPlannerTest.java index 1245179cc7e60..93ab32b6117ea 100644 --- a/fe/fe-core/src/test/java/com/starrocks/analysis/SelectStmtWithDecimalTypesNewPlannerTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/analysis/SelectStmtWithDecimalTypesNewPlannerTest.java @@ -610,7 +610,7 @@ public void testSumDistinctWithRewriteMultiDistinctRuleTakeEffect() throws Excep "args: DECIMAL128; result: DECIMAL128(38,3); " + "args nullable: true; result nullable: true]"; String plan = UtFrameUtils.getVerboseFragmentPlan(ctx, sql); - Assert.assertTrue(plan.contains(expectPhase1Snippet) && plan.contains(expectPhase2Snippet)); + Assert.assertTrue(plan, plan.contains(expectPhase1Snippet) && plan.contains(expectPhase2Snippet)); ctx.getSessionVariable().setNewPlanerAggStage(oldStage); ctx.getSessionVariable().setCboCteReuse(oldCboCteReUse); @@ -642,27 +642,27 @@ public void testOnePhaseAvgDistinctDecimal() throws Exception { String sql = "select avg(distinct col_decimal32p9s2) from db1.decimal_table"; String plan = UtFrameUtils.getVerboseFragmentPlan(ctx, sql); - String multiDistinctSumSnippet = "aggregate: multi_distinct_sum[([8: col_decimal32p9s2, DECIMAL32(9,2), false]); " + - "args: DECIMAL32; result: DECIMAL128(38,2)"; - String multiDistinctCountSnippet = "multi_distinct_count[([10: col_decimal32p9s2, DECIMAL32(9,2), false]);" + - " args: DECIMAL32; result: BIGINT"; - Assert.assertTrue(plan, plan.contains(multiDistinctSumSnippet) && plan.contains(multiDistinctCountSnippet)); + String localAggSnippet = "aggregate: avg[([2: col_decimal32p9s2, DECIMAL32(9,2), false]); " + + "args: DECIMAL32; result: VARBINARY; args nullable: false; result nullable: true]"; + String globalAggSnippet = "aggregate: avg[([6: avg, VARBINARY, true]); args: DECIMAL32; " + + "result: DECIMAL128(38,8); args nullable: true; result nullable: true]"; + Assert.assertTrue(plan, plan.contains(localAggSnippet) && plan.contains(globalAggSnippet)); sql = "select avg(distinct col_decimal64p13s0) from db1.decimal_table"; plan = UtFrameUtils.getVerboseFragmentPlan(ctx, sql); - multiDistinctSumSnippet = "multi_distinct_sum[([8: col_decimal64p13s0, DECIMAL64(13,0), false]); " + - "args: DECIMAL64; result: DECIMAL128(38,0)"; - multiDistinctCountSnippet = "multi_distinct_count[([10: col_decimal64p13s0, DECIMAL64(13,0), false]); " + - "args: DECIMAL64; result: BIGINT"; - Assert.assertTrue(plan.contains(multiDistinctSumSnippet) && plan.contains(multiDistinctCountSnippet)); + localAggSnippet = "aggregate: avg[([3: col_decimal64p13s0, DECIMAL64(13,0), false]); " + + "args: DECIMAL64; result: VARBINARY;"; + globalAggSnippet = "aggregate: avg[([6: avg, VARBINARY, true]); args: DECIMAL64; result: DECIMAL128(38,6); " + + "args nullable: true; result nullable: true"; + Assert.assertTrue(plan,plan.contains(localAggSnippet) && plan.contains(globalAggSnippet)); sql = "select avg(distinct col_decimal128p20s3) from db1.decimal_table"; plan = UtFrameUtils.getVerboseFragmentPlan(ctx, sql); - multiDistinctSumSnippet = "multi_distinct_sum[([8: col_decimal128p20s3, DECIMAL128(20,3), true]); " + - "args: DECIMAL128; result: DECIMAL128(38,3)"; - multiDistinctCountSnippet = "multi_distinct_count[([10: col_decimal128p20s3, DECIMAL128(20,3), true]); " + - "args: DECIMAL128; result: BIGINT"; - Assert.assertTrue(plan.contains(multiDistinctSumSnippet) && plan.contains(multiDistinctCountSnippet)); + localAggSnippet = "aggregate: avg[([5: col_decimal128p20s3, DECIMAL128(20,3), true]); args: DECIMAL128; " + + "result: VARBINARY; args nullable: true; result nullable: true"; + globalAggSnippet = "aggregate: avg[([6: avg, VARBINARY, true]); args: DECIMAL128; result: DECIMAL128(38,9);" + + " args nullable: true; result nullable: true]"; + Assert.assertTrue(plan,plan.contains(localAggSnippet) && plan.contains(globalAggSnippet)); ctx.getSessionVariable().setNewPlanerAggStage(oldStage); } @@ -674,28 +674,27 @@ public void testTwoPhaseAvgDistinct() throws Exception { ctx.getSessionVariable().setNewPlanerAggStage(2); String sql = "select avg(distinct col_decimal32p9s2) from db1.decimal_table"; String plan = UtFrameUtils.getVerboseFragmentPlan(ctx, sql); - String expectPhase1Snippet = "multi_distinct_sum[([8: col_decimal32p9s2, DECIMAL32(9,2), false]); " + - "args: DECIMAL32; result: VARBINARY"; - String expectPhase2Snippet = "multi_distinct_sum[([7: sum, VARBINARY, true]); " + - "args: DECIMAL32; result: DECIMAL128(38,2)"; - Assert.assertTrue(plan.contains(expectPhase1Snippet) && plan.contains(expectPhase2Snippet)); + String localAggSnippet = "aggregate: avg[([2: col_decimal32p9s2, DECIMAL32(9,2), false]); " + + "args: DECIMAL32; result: VARBINARY; args nullable: false; result nullable: true]"; + String globalAggSnippet = "aggregate: avg[([6: avg, VARBINARY, true]); args: DECIMAL32; " + + "result: DECIMAL128(38,8); args nullable: true; result nullable: true]"; + Assert.assertTrue(plan, plan.contains(localAggSnippet) && plan.contains(globalAggSnippet)); sql = "select avg(distinct col_decimal64p13s0) from db1.decimal_table"; plan = UtFrameUtils.getVerboseFragmentPlan(ctx, sql); - expectPhase1Snippet = "multi_distinct_sum[([8: col_decimal64p13s0, DECIMAL64(13,0), false]); " + - "args: DECIMAL64; result: VARBINARY"; - expectPhase2Snippet = "multi_distinct_sum[([7: sum, VARBINARY, true]); " + - "args: DECIMAL64; result: DECIMAL128(38,0)"; - Assert.assertTrue(plan.contains(expectPhase1Snippet) && plan.contains(expectPhase2Snippet)); - + localAggSnippet = "aggregate: avg[([3: col_decimal64p13s0, DECIMAL64(13,0), false]); " + + "args: DECIMAL64; result: VARBINARY;"; + globalAggSnippet = "aggregate: avg[([6: avg, VARBINARY, true]); args: DECIMAL64; result: DECIMAL128(38,6); " + + "args nullable: true; result nullable: true"; + Assert.assertTrue(plan,plan.contains(localAggSnippet) && plan.contains(globalAggSnippet)); sql = "select avg(distinct col_decimal128p20s3) from db1.decimal_table"; plan = UtFrameUtils.getVerboseFragmentPlan(ctx, sql); - expectPhase1Snippet = "multi_distinct_sum[([8: col_decimal128p20s3, DECIMAL128(20,3), true]); " + - "args: DECIMAL128; result: VARBINARY"; - expectPhase2Snippet = "multi_distinct_sum[([7: sum, VARBINARY, true]); " + - "args: DECIMAL128; result: DECIMAL128(38,3)"; - Assert.assertTrue(plan.contains(expectPhase1Snippet) && plan.contains(expectPhase2Snippet)); + localAggSnippet = "aggregate: avg[([5: col_decimal128p20s3, DECIMAL128(20,3), true]); args: DECIMAL128; " + + "result: VARBINARY; args nullable: true; result nullable: true"; + globalAggSnippet = "aggregate: avg[([6: avg, VARBINARY, true]); args: DECIMAL128; result: DECIMAL128(38,9);" + + " args nullable: true; result nullable: true]"; + Assert.assertTrue(plan,plan.contains(localAggSnippet) && plan.contains(globalAggSnippet)); ctx.getSessionVariable().setNewPlanerAggStage(oldStage); } @@ -746,7 +745,7 @@ public void testAvgDistinctWithRewriteMultiDistinctRuleTakeEffect() throws Excep expectPhase1Snippet = removeSlotIds(expectPhase1Snippet); expectPhase2Snippet = removeSlotIds(expectPhase2Snippet); projectOutputColumns = removeSlotIds(projectOutputColumns); - Assert.assertTrue(plan.contains(expectPhase1Snippet) && + Assert.assertTrue(plan, plan.contains(expectPhase1Snippet) && plan.contains(expectPhase2Snippet) && plan.contains(projectOutputColumns)); @@ -780,31 +779,16 @@ public void testAvgDistinctWithRewriteMultiDistinctByCTERuleTakeEffect() throws } @Test - public void testAvgDistinctNonDecimalTypeWithRewriteMultiDistinctRuleTakeEffect() throws Exception { + public void testAvgDistinctNonDecimalType() throws Exception { int oldStage = ctx.getSessionVariable().getNewPlannerAggStage(); boolean oldCboCteReUse = ctx.getSessionVariable().isCboCteReuse(); ctx.getSessionVariable().setNewPlanerAggStage(2); String sql = "select avg(distinct key0) from db1.decimal_table"; - String[] snippets = new String[]{ - "cast([multi_distinct_sum, BIGINT, true] as DOUBLE) / " + - "cast([multi_distinct_count, BIGINT, false] as DOUBLE)", - "multi_distinct_count[([multi_distinct_count, VARBINARY, false]); " + - "args: INT; result: BIGINT; args nullable: true; result nullable: false]", - "multi_distinct_sum[([multi_distinct_sum, VARBINARY, true]); " + - "args: INT; result: BIGINT; args nullable: true; result nullable: true]", - "multi_distinct_count[([key0, INT, false]); args: INT; result: VARBINARY; " + - "args nullable: false; result nullable: false]", - "multi_distinct_sum[([key0, INT, false]); " + - "args: INT; result: VARBINARY; args nullable: false; result nullable: true]", - }; - - ctx.getSessionVariable().setCboCteReuse(false); - String disableCtePlan = removeSlotIds(UtFrameUtils.getVerboseFragmentPlan(ctx, sql)); - Assert.assertTrue(Arrays.asList(snippets).stream().anyMatch(s -> disableCtePlan.contains(s))); - - ctx.getSessionVariable().setCboCteReuse(true); - String enableCtePlan = removeSlotIds(UtFrameUtils.getVerboseFragmentPlan(ctx, sql)); - Assert.assertTrue(Arrays.asList(snippets).stream().anyMatch(s -> enableCtePlan.contains(s))); + String plan = UtFrameUtils.getVerboseFragmentPlan(ctx, sql); + Assert.assertTrue(plan, plan.contains("aggregate: avg[([1: key0, INT, false]); args: INT; result: VARBINARY; " + + "args nullable: false; result nullable: true]")); + Assert.assertTrue(plan, plan.contains("aggregate: avg[([6: avg, VARBINARY, true]); args: INT; result: DOUBLE; " + + "args nullable: true; result nullable: true]")); ctx.getSessionVariable().setNewPlanerAggStage(oldStage); ctx.getSessionVariable().setCboCteReuse(oldCboCteReUse); diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java index a462b2c7be674..54b6aac5352ae 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java @@ -1392,30 +1392,37 @@ public void testMultiCountDistinctWithNoneGroup4() throws Exception { public void testMultiAvgDistinctWithNoneGroup() throws Exception { String sql = "select avg(distinct t1b) from test_all_type"; String plan = getFragmentPlan(sql); - assertContains(plan, "19:Project\n" + - " | : CAST(12: sum AS DOUBLE) / CAST(14: count AS DOUBLE)"); + assertContains(plan, "4:AGGREGATE (update serialize)\n" + + " | output: avg(2: t1b)\n" + + " | group by: \n" + + " | \n" + + " 3:AGGREGATE (merge serialize)\n" + + " | group by: 2: t1b"); sql = "select avg(distinct t1b), count(distinct t1b) from test_all_type"; plan = getFragmentPlan(sql); - assertContains(plan, "19:Project\n" + - " | : CAST(14: sum AS DOUBLE) / CAST(12: count AS DOUBLE)\n" + - " | : 12: count"); + assertContains(plan, "4:AGGREGATE (update serialize)\n" + + " | output: avg(2: t1b), count(2: t1b)\n" + + " | group by: \n" + + " | \n" + + " 3:AGGREGATE (merge serialize)\n" + + " | group by: 2: t1b"); sql = "select avg(distinct t1b), count(distinct t1b), sum(distinct t1b) from test_all_type"; plan = getFragmentPlan(sql); - assertContains(plan, "9:Project\n" + - " | : CAST(13: sum AS DOUBLE) / CAST(12: count AS DOUBLE)\n" + - " | : 12: count\n" + - " | : 13: sum"); + assertContains(plan, "4:AGGREGATE (update serialize)\n" + + " | output: avg(2: t1b), count(2: t1b), sum(2: t1b)\n" + + " | group by: \n" + + " | \n" + + " 3:AGGREGATE (merge serialize)\n" + + " | group by: 2: t1b"); sql = "select avg(distinct t1b + 1), count(distinct t1b+1), sum(distinct t1b + 1), count(t1b) from test_all_type"; plan = getFragmentPlan(sql); - assertContains(plan, " 27:Project\n" + - " | : CAST(14: sum AS DOUBLE) / CAST(13: count AS DOUBLE)\n" + - " | : 13: count\n" + - " | : 14: sum\n" + - " | : 15: count"); + assertContains(plan, "7:AGGREGATE (merge finalize)\n" + + " | output: avg(12: avg), count(13: count), sum(14: sum), count(15: count)\n" + + " | group by:"); sql = "select avg(distinct t1b + 1), count(distinct t1b), sum(distinct t1c), count(t1c), sum(t1c) from test_all_type"; @@ -1429,24 +1436,21 @@ public void testMultiAvgDistinctWithNoneGroup() throws Exception { sql = "select avg(distinct 1), count(distinct null), count(distinct 1) from test_all_type"; plan = getFragmentPlan(sql); - assertContains(plan, "16:AGGREGATE (update serialize)\n" + - " | output: multi_distinct_sum(1)\n" + - " | group by: \n" + - " | \n" + - " 15:Project\n" + - " | : 15: auto_fill_col"); + assertContains(plan, "4:AGGREGATE (merge finalize)\n" + + " | output: avg(11: avg), multi_distinct_count(12: count), multi_distinct_count(13: count)"); sql = "select avg(distinct 1), count(distinct null), count(distinct 1), " + "count(distinct (t1a + t1c)), sum(t1c) from test_all_type"; plan = getFragmentPlan(sql); - assertContains(plan, "26:AGGREGATE (update serialize)\n" + - " | output: multi_distinct_sum(1)\n" + + assertContains(plan, "7:AGGREGATE (merge finalize)\n" + + " | output: avg(12: avg), count(13: count), count(14: count), count(15: count), sum(16: sum)"); + assertContains(plan, "5:AGGREGATE (update serialize)\n" + + " | output: avg(1), count(NULL), count(1), count(11: expr), sum(16: sum)\n" + " | group by: \n" + " | \n" + - " 25:Project\n" + - " | : 3: t1c"); - assertContains(plan, "4:AGGREGATE (update serialize)\n" + - " | output: multi_distinct_count(NULL)"); + " 4:AGGREGATE (merge serialize)\n" + + " | output: sum(16: sum)\n" + + " | group by: 11: expr"); } @Test @@ -1648,11 +1652,17 @@ public void testOuterJoinSatisfyAgg() throws Exception { public void testAvgCountDistinctWithHaving() throws Exception { String sql = "select avg(distinct s_suppkey), count(distinct s_acctbal) " + "from supplier having avg(distinct s_suppkey) > 3 ;"; + connectContext.getSessionVariable().setOptimizerExecuteTimeout(-1); String plan = getFragmentPlan(sql); - assertContains(plan, " 28:NESTLOOP JOIN\n" + - " | join op: INNER JOIN\n" + - " | colocate: false, reason: \n" + - " | other join predicates: CAST(12: sum AS DOUBLE) / CAST(14: count AS DOUBLE) > 3.0"); + assertContains(plan, "30:SELECT\n" + + " | predicates: 9: avg > 3.0\n" + + " | \n" + + " 29:Project\n" + + " | : CAST(12: sum AS DOUBLE) / CAST(14: count AS DOUBLE)\n" + + " | : 10: count\n" + + " | \n" + + " 28:NESTLOOP JOIN\n" + + " | join op: CROSS JOIN"); } @Test diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/DecimalTypeTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/DecimalTypeTest.java index ca35dec7d7984..b35f545d6d0f0 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/DecimalTypeTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/DecimalTypeTest.java @@ -220,7 +220,7 @@ public void testArrayAggDecimal() throws Exception { try { String sql = "select array_agg(distinct c_1_6) from tab1 group by c_1_0,c_1_1;"; String plan = getVerboseExplain(sql); - assertContains(plan, "array_agg_distinct"); + assertNotContains(plan, "array_agg_distinct"); } finally { connectContext.getSessionVariable().setNewPlanerAggStage(stage); } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/DistinctAggTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/DistinctAggTest.java index 129f7823b45bb..c0fe8fb72b89c 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/DistinctAggTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/DistinctAggTest.java @@ -16,6 +16,7 @@ import com.google.common.collect.Lists; import com.starrocks.common.FeConstants; +import org.junit.Test; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -39,6 +40,35 @@ public void testSqlWithDistinctLimit(String sql, String expectedPlan) throws Exc assertContains(plan, expectedPlan); } + @Test + public void testDistinctConstants() throws Exception { + String sql = "select count(distinct 1, 2, 3, 4), sum(distinct 1), avg(distinct 1), " + + "group_concat(distinct 1, 2 order by 1), array_agg(distinct 1 order by 1) from t0 group by v2;"; + String plan = getFragmentPlan(sql); + assertContains(plan, "4:AGGREGATE (update finalize)\n" + + " | output: count(if(1 IS NULL, NULL, if(2 IS NULL, NULL, if(3 IS NULL, NULL, 4)))), sum(1), avg(1), " + + "group_concat('1', '2', ','), array_agg(1)"); + sql = "select count(distinct 1, 2, 3, 4) from t0 group by v2"; + plan = getFragmentPlan(sql); + assertContains(plan, "3:AGGREGATE (merge finalize)\n" + + " | output: multi_distinct_count(4: count, 1, 2, 3, 4)"); + + sql = "select count(distinct v3, 1) from t0 group by v2"; + plan = getFragmentPlan(sql); + assertContains(plan, "4:AGGREGATE (update finalize)\n" + + " | output: count(if(3: v3 IS NULL, NULL, 1))"); + + sql = "select array_agg(distinct 1.33) from t0"; + plan = getFragmentPlan(sql); + assertContains(plan, "2:AGGREGATE (update serialize)\n" + + " | output: array_agg(DISTINCT 1.33)"); + + sql = "select group_concat(distinct 1.33) from t0"; + plan = getFragmentPlan(sql); + assertContains(plan, "2:AGGREGATE (update serialize)\n" + + " | output: group_concat(DISTINCT '1.33', ',')"); + } + private static Stream sqlWithDistinctLimit() { List argumentsList = Lists.newArrayList(); argumentsList.add(Arguments.of("select count(distinct v1, v2) from (select * from t0 limit 2) t", diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinTest.java index 14c46fcc4c6be..96f43d33cf279 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinTest.java @@ -30,6 +30,7 @@ import org.junit.Test; public class JoinTest extends PlanTestBase { + @Test public void testColocateDistributeSatisfyShuffleColumns() throws Exception { FeConstants.runningUnitTest = true; diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java index f7de44bbc79ff..9a35aa26fdd9a 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java @@ -339,7 +339,7 @@ public void testDecodeNodeRewriteMultiAgg() sql = "select count(distinct S_ADDRESS), count(distinct S_NATIONKEY) from supplier " + "having count(1) > 0"; plan = getVerboseExplain(sql); - Assert.assertFalse(plan, plan.contains("dict_col=")); + Assert.assertTrue(plan, plan.contains("dict_col=S_ADDRESS")); Assert.assertFalse(plan, plan.contains("Decode")); } finally { connectContext.getSessionVariable().setCboCteReuse(cboCteReuse); diff --git a/test/sql/test_agg_function/R/test_bitmap_agg b/test/sql/test_agg_function/R/test_bitmap_agg index c22edd8d5cc0d..2e4470a4ff745 100644 --- a/test/sql/test_agg_function/R/test_bitmap_agg +++ b/test/sql/test_agg_function/R/test_bitmap_agg @@ -239,24 +239,6 @@ select c1, bitmap_union_count(to_bitmap(c4)), BITMAP_TO_STRING(bitmap_agg(c4)) f 5 0 6 0 None -- !result -select c1, count(distinct c2), count(distinct c3), count(distinct c4) from t1 group by c1 order by c1; --- result: -1 1 1 1 -2 1 1 1 -3 1 1 1 -4 0 0 0 -5 1 1 1 -6 0 0 0 --- !result -select c1, multi_distinct_count(c2), multi_distinct_count(c3), multi_distinct_count(c4) from t1 group by c1 order by c1; --- result: -1 1 1 1 -2 1 1 1 -3 1 1 1 -4 0 0 0 -5 1 0 0 -6 0 0 0 --- !result drop materialized view mv2; -- result: -- !result \ No newline at end of file diff --git a/test/sql/test_agg_function/R/test_count_distinct b/test/sql/test_agg_function/R/test_count_distinct index c7abfe66ac155..63abca70e1668 100644 --- a/test/sql/test_agg_function/R/test_count_distinct +++ b/test/sql/test_agg_function/R/test_count_distinct @@ -3,7 +3,10 @@ CREATE TABLE `test_cc` ( `v1` varchar(65533) NULL COMMENT "", `v2` varchar(65533) NULL COMMENT "", `v3` datetime NULL COMMENT "", - `v4` int null + `v4` int null, + `v5` decimal(32, 2) null, + `v6` array null, + `v7` struct NULL ) ENGINE=OLAP DUPLICATE KEY(v1, v2, v3) PARTITION BY RANGE(`v3`) @@ -21,28 +24,28 @@ PROPERTIES ( ); -- result: -- !result -insert into test_cc values('a','a', '2022-04-18 01:01:00', 1); +insert into test_cc values('a','a', '2022-04-18 01:01:00', 1, 1.2, [1, 2, 3], row(1, 'a')); -- result: -- !result -insert into test_cc values('a','b', '2022-04-18 02:01:00', 2); +insert into test_cc values('a','b', '2022-04-18 02:01:00', 2, 1.3, [2, 1, 3], row(2, 'a')); -- result: -- !result -insert into test_cc values('a','a', '2022-04-18 02:05:00', 1); +insert into test_cc values('a','a', '2022-04-18 02:05:00', 1, 2.3, [2, 2, 3], row(3, 'a')); -- result: -- !result -insert into test_cc values('a','b', '2022-04-18 02:15:00', 3); +insert into test_cc values('a','b', '2022-04-18 02:15:00', 3, 3.31, [2, 2, 3], row(4, 'a')); -- result: -- !result -insert into test_cc values('a','b', '2022-04-18 03:15:00', 1); +insert into test_cc values('a','b', '2022-04-18 03:15:00', 1, 100.3, [3, 1, 3], row(2, 'a')); -- result: -- !result -insert into test_cc values('c','a', '2022-04-18 03:45:00', 1); +insert into test_cc values('c','a', '2022-04-18 03:45:00', 1, 200.3, [2, 2, 3], row(3, 'a')); -- result: -- !result -insert into test_cc values('c','a', '2022-04-18 03:25:00', 2); +insert into test_cc values('c','a', '2022-04-18 03:25:00', 2, 300.3, null, row(2, 'a')); -- result: -- !result -insert into test_cc values('c','a', '2022-04-18 03:27:00', 3); +insert into test_cc values('c','a', '2022-04-18 03:27:00', 3, 400.3, [3, 1, 3], null); -- result: -- !result select v2, count(1), count(distinct v1) from test_cc group by v2; @@ -52,11 +55,106 @@ b 3 1 -- !result select v2, bitmap_union_count(to_bitmap(v4)), count(distinct v1) from test_cc group by v2; -- result: -a 3 2 b 3 1 +a 3 2 -- !result select v2, hll_union_agg(hll_hash(v4)), count(distinct v1) from test_cc group by v2; -- result: a 3 2 b 3 1 -- !result +select count(distinct 1, 2, 3, 4) from test_cc; +-- result: +1 +-- !result +select /*+ new_planner_agg_stage = 3 */ count(distinct 1, 2, 3, 4) from test_cc group by v2; +-- result: +1 +1 +-- !result +select count(distinct 1, v2) from test_cc; +-- result: +2 +-- !result +select /*+ new_planner_agg_stage = 2 */ count(distinct 1, v2) from test_cc; +-- result: +2 +-- !result +select count(distinct 1, 2, 3, 4), sum(distinct 1), avg(distinct 1), group_concat(distinct 1, 2 order by 1), array_agg(distinct 1.3 order by null) from test_cc; +-- result: +1 1 1.0 12 [1.3] +-- !result +select count(distinct 1, 2, 3, 4), sum(distinct 1), avg(distinct 1), group_concat(distinct 1, 2 order by 1), array_agg(distinct 1.3 order by null) from test_cc group by v2; +-- result: +1 1 1.0 12 [1.3] +1 1 1.0 12 [1.3] +-- !result +select v2, count(distinct v1), sum(distinct v1), avg(distinct v1), group_concat(distinct 1, 2), array_agg(distinct 1.3 order by null) from test_cc group by v2; +-- result: +b 1 None None 12 [1.3] +a 2 None None 12 [1.3] +-- !result +select v2, count(distinct v4), sum(distinct v4), avg(distinct v4), group_concat(distinct 1, 2), array_agg(distinct 1.3 order by null) from test_cc group by v2; +-- result: +b 3 6 2.0 12 [1.3] +a 3 6 2.0 12 [1.3] +-- !result +select v2, count(distinct v5), sum(distinct v5), avg(distinct v5), group_concat(distinct 1, 2), array_agg(distinct 1.3 order by null) from test_cc group by v2; +-- result: +a 5 904.40 180.88000000 12 [1.3] +b 3 104.91 34.97000000 12 [1.3] +-- !result +select v2, count(distinct v6), array_agg(distinct v6 order by 1), group_concat(distinct 1, 2), array_agg(distinct 1.3 order by null) from test_cc group by v2; +-- result: +a 3 [null,[1,2,3],[2,2,3],[3,1,3]] 12 [1.3] +b 3 [[2,1,3],[2,2,3],[3,1,3]] 12 [1.3] +-- !result +select v2, count(distinct v7), array_agg(distinct v7 order by 1), group_concat(distinct 1, 2 order by 1), array_agg(distinct 1.3 order by null) from test_cc group by v2; +-- result: +E: (1064, "Getting analyzing error from line 1, column 31 to line 1, column 63. Detail message: array_agg can't support order by the 1-th input with type of struct.") +-- !result +select v2, count(distinct v1, v3, v6), sum(distinct v1), avg(distinct v1), array_agg(v5 order by 1), group_concat(distinct 1, 2), array_agg(distinct 1.3 order by null) from test_cc group by v2; +-- result: +a 4 None None [1.20,2.30,200.30,300.30,400.30] 12 [1.3] +b 3 None None [1.30,3.31,100.30] 12 [1.3] +-- !result +select count(distinct v4, v5), sum(distinct v4), avg(distinct v4), group_concat(distinct v4, v5, 2 order by 1,2), array_agg(distinct 1.456 order by 1) from test_cc; +-- result: +8 6 2.0 11.202,12.302,1100.302,1200.302,21.302,2300.302,33.312,3400.302 [1.456] +-- !result +select v2, count(distinct v4, v5), sum(distinct v5), avg(distinct v5), group_concat(distinct v4, v5, 2 order by 1,2), array_agg(distinct 1.456 order by 1) from test_cc group by v2; +-- result: +a 5 904.40 180.88000000 11.202,12.302,1200.302,2300.302,3400.302 [1.456] +b 3 104.91 34.97000000 1100.302,21.302,33.312 [1.456] +-- !result +select v2, count(distinct v3, v5), sum(distinct v4), avg(distinct v5), group_concat(distinct v4, v5, 2 order by 1,2), array_agg(distinct 1.456 order by 1) from test_cc group by v2, v3; +-- result: +a 1 1 200.30000000 1200.302 [1.456] +a 1 3 400.30000000 3400.302 [1.456] +b 1 2 1.30000000 21.302 [1.456] +b 1 1 100.30000000 1100.302 [1.456] +a 1 1 2.30000000 12.302 [1.456] +b 1 3 3.31000000 33.312 [1.456] +a 1 1 1.20000000 11.202 [1.456] +a 1 2 300.30000000 2300.302 [1.456] +-- !result +select count(distinct v4, v5), sum(distinct v4), avg(distinct v4), group_concat(distinct v4, v5, 2 order by 1,2), array_agg(distinct v4 order by 1) from test_cc; +-- result: +8 6 2.0 11.202,12.302,1100.302,1200.302,21.302,2300.302,33.312,3400.302 [1,2,3] +-- !result +select v2, count(distinct v4, v5), sum(distinct v5), avg(distinct v5), group_concat(distinct v4, v5, 2 order by 1,2), array_agg(distinct v5 order by 1) from test_cc group by v2; +-- result: +b 3 104.91 34.97000000 1100.302,21.302,33.312 [1.30,3.31,100.30] +a 5 904.40 180.88000000 11.202,12.302,1200.302,2300.302,3400.302 [1.20,2.30,200.30,300.30,400.30] +-- !result +select v2, count(distinct v3, v5), sum(distinct v4), avg(distinct v5), group_concat(distinct v4, v5, 2 order by 1,2), array_agg(distinct v5 order by 1) from test_cc group by v2, v3; +-- result: +b 1 2 1.30000000 21.302 [1.30] +a 1 2 300.30000000 2300.302 [300.30] +b 1 1 100.30000000 1100.302 [100.30] +a 1 3 400.30000000 3400.302 [400.30] +a 1 1 1.20000000 11.202 [1.20] +a 1 1 2.30000000 12.302 [2.30] +a 1 1 200.30000000 1200.302 [200.30] +b 1 3 3.31000000 33.312 [3.31] +-- !result \ No newline at end of file diff --git a/test/sql/test_agg_function/T/test_bitmap_agg b/test/sql/test_agg_function/T/test_bitmap_agg index d2b9e0963166e..61cf792a1320e 100644 --- a/test/sql/test_agg_function/T/test_bitmap_agg +++ b/test/sql/test_agg_function/T/test_bitmap_agg @@ -72,7 +72,9 @@ SELECT BITMAP_TO_STRING(BITMAP_UNION(TO_BITMAP(c4))) FROM t1; select BITMAP_TO_STRING(bitmap_union(to_bitmap(c2))), BITMAP_TO_STRING(bitmap_union(to_bitmap(c3))), BITMAP_TO_STRING(bitmap_agg(c4)) from t1 group by c1 order by c1; select c1, count(distinct c2), bitmap_union(to_bitmap(c3)), bitmap_agg(c4) from t1 group by c1 order by c1; select c1, bitmap_union_count(to_bitmap(c4)), BITMAP_TO_STRING(bitmap_agg(c4)) from t1 group by c1 order by c1; -select c1, count(distinct c2), count(distinct c3), count(distinct c4) from t1 group by c1 order by c1; -select c1, multi_distinct_count(c2), multi_distinct_count(c3), multi_distinct_count(c4) from t1 group by c1 order by c1; + +-- the below two sqls have different result. +-- select c1, count(distinct c2), count(distinct c3), count(distinct c4) from t1 group by c1 order by c1; +-- select c1, multi_distinct_count(c2), multi_distinct_count(c3), multi_distinct_count(c4) from t1 group by c1 order by c1; drop materialized view mv2; diff --git a/test/sql/test_agg_function/T/test_count_distinct b/test/sql/test_agg_function/T/test_count_distinct index 3f921b3c68532..6a7211aa9222b 100644 --- a/test/sql/test_agg_function/T/test_count_distinct +++ b/test/sql/test_agg_function/T/test_count_distinct @@ -3,7 +3,10 @@ CREATE TABLE `test_cc` ( `v1` varchar(65533) NULL COMMENT "", `v2` varchar(65533) NULL COMMENT "", `v3` datetime NULL COMMENT "", - `v4` int null + `v4` int null, + `v5` decimal(32, 2) null, + `v6` array null, + `v7` struct NULL ) ENGINE=OLAP DUPLICATE KEY(v1, v2, v3) PARTITION BY RANGE(`v3`) @@ -21,15 +24,41 @@ PROPERTIES ( ); -insert into test_cc values('a','a', '2022-04-18 01:01:00', 1); -insert into test_cc values('a','b', '2022-04-18 02:01:00', 2); -insert into test_cc values('a','a', '2022-04-18 02:05:00', 1); -insert into test_cc values('a','b', '2022-04-18 02:15:00', 3); -insert into test_cc values('a','b', '2022-04-18 03:15:00', 1); -insert into test_cc values('c','a', '2022-04-18 03:45:00', 1); -insert into test_cc values('c','a', '2022-04-18 03:25:00', 2); -insert into test_cc values('c','a', '2022-04-18 03:27:00', 3); +insert into test_cc values('a','a', '2022-04-18 01:01:00', 1, 1.2, [1, 2, 3], row(1, 'a')); +insert into test_cc values('a','b', '2022-04-18 02:01:00', 2, 1.3, [2, 1, 3], row(2, 'a')); +insert into test_cc values('a','a', '2022-04-18 02:05:00', 1, 2.3, [2, 2, 3], row(3, 'a')); +insert into test_cc values('a','b', '2022-04-18 02:15:00', 3, 3.31, [2, 2, 3], row(4, 'a')); +insert into test_cc values('a','b', '2022-04-18 03:15:00', 1, 100.3, [3, 1, 3], row(2, 'a')); +insert into test_cc values('c','a', '2022-04-18 03:45:00', 1, 200.3, [2, 2, 3], row(3, 'a')); +insert into test_cc values('c','a', '2022-04-18 03:25:00', 2, 300.3, null, row(2, 'a')); +insert into test_cc values('c','a', '2022-04-18 03:27:00', 3, 400.3, [3, 1, 3], null); select v2, count(1), count(distinct v1) from test_cc group by v2; select v2, bitmap_union_count(to_bitmap(v4)), count(distinct v1) from test_cc group by v2; select v2, hll_union_agg(hll_hash(v4)), count(distinct v1) from test_cc group by v2; + +select count(distinct 1, 2, 3, 4) from test_cc; +select /*+ new_planner_agg_stage = 3 */ count(distinct 1, 2, 3, 4) from test_cc group by v2; +select count(distinct 1, v2) from test_cc; +select /*+ new_planner_agg_stage = 2 */ count(distinct 1, v2) from test_cc; +select count(distinct 1, 2, 3, 4), sum(distinct 1), avg(distinct 1), group_concat(distinct 1, 2 order by 1), array_agg(distinct 1.3 order by null) from test_cc; +select count(distinct 1, 2, 3, 4), sum(distinct 1), avg(distinct 1), group_concat(distinct 1, 2 order by 1), array_agg(distinct 1.3 order by null) from test_cc group by v2; + + +select v2, count(distinct v1), sum(distinct v1), avg(distinct v1), group_concat(distinct 1, 2), array_agg(distinct 1.3 order by null) from test_cc group by v2; +select v2, count(distinct v4), sum(distinct v4), avg(distinct v4), group_concat(distinct 1, 2), array_agg(distinct 1.3 order by null) from test_cc group by v2; +select v2, count(distinct v5), sum(distinct v5), avg(distinct v5), group_concat(distinct 1, 2), array_agg(distinct 1.3 order by null) from test_cc group by v2; +select v2, count(distinct v6), array_agg(distinct v6 order by 1), group_concat(distinct 1, 2), array_agg(distinct 1.3 order by null) from test_cc group by v2; +select v2, count(distinct v7), array_agg(distinct v7 order by 1), group_concat(distinct 1, 2 order by 1), array_agg(distinct 1.3 order by null) from test_cc group by v2; + +select v2, count(distinct v1, v3, v6), sum(distinct v1), avg(distinct v1), array_agg(v5 order by 1), group_concat(distinct 1, 2), array_agg(distinct 1.3 order by null) from test_cc group by v2; +select count(distinct v4, v5), sum(distinct v4), avg(distinct v4), group_concat(distinct v4, v5, 2 order by 1,2), array_agg(distinct 1.456 order by 1) from test_cc; +select v2, count(distinct v4, v5), sum(distinct v5), avg(distinct v5), group_concat(distinct v4, v5, 2 order by 1,2), array_agg(distinct 1.456 order by 1) from test_cc group by v2; +select v2, count(distinct v3, v5), sum(distinct v4), avg(distinct v5), group_concat(distinct v4, v5, 2 order by 1,2), array_agg(distinct 1.456 order by 1) from test_cc group by v2, v3; + +select count(distinct v4, v5), sum(distinct v4), avg(distinct v4), group_concat(distinct v4, v5, 2 order by 1,2), array_agg(distinct v4 order by 1) from test_cc; +select v2, count(distinct v4, v5), sum(distinct v5), avg(distinct v5), group_concat(distinct v4, v5, 2 order by 1,2), array_agg(distinct v5 order by 1) from test_cc group by v2; +select v2, count(distinct v3, v5), sum(distinct v4), avg(distinct v5), group_concat(distinct v4, v5, 2 order by 1,2), array_agg(distinct v5 order by 1) from test_cc group by v2, v3; + + +