diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java index 191a26086eee1c..6c728fb2cdabb8 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java @@ -723,6 +723,8 @@ public static MaterializedViewRewriteMode parse(String str) { public static final String ENABLE_CONNECTOR_SINK_WRITER_SCALING = "enable_connector_sink_writer_scaling"; + public static final String ENABLE_CONSTANT_EXECUTE_IN_FE = "enable_constant_execute_in_fe"; + public static final List DEPRECATED_VARIABLES = ImmutableList.builder() .add(CODEGEN_LEVEL) .add(MAX_EXECUTION_TIME) @@ -2039,6 +2041,9 @@ public Optional isFollowerForwardToLeaderOpt() { @VarAttr(name = ENABLE_PIPELINE_LEVEL_SHUFFLE, flag = VariableMgr.INVISIBLE) private boolean enablePipelineLevelShuffle = true; + @VarAttr(name = ENABLE_CONSTANT_EXECUTE_IN_FE) + private boolean enableConstantExecuteInFE = true; + public int getExprChildrenLimit() { return exprChildrenLimit; } @@ -3693,6 +3698,14 @@ public void setEnablePredicateMoveAround(boolean enablePredicateMoveAround) { this.enablePredicateMoveAround = enablePredicateMoveAround; } + public boolean isEnableConstantExecuteInFE() { + return enableConstantExecuteInFE; + } + + public void setEnableConstantExecuteInFE(boolean enableConstantExecuteInFE) { + this.enableConstantExecuteInFE = enableConstantExecuteInFE; + } + // Serialize to thrift object // used for rest api public TQueryOptions toThrift() { diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/StmtExecutor.java b/fe/fe-core/src/main/java/com/starrocks/qe/StmtExecutor.java index 269fb7e79d469a..9651e0c3643171 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/StmtExecutor.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/StmtExecutor.java @@ -111,6 +111,7 @@ import com.starrocks.proto.QueryStatisticsItemPB; import com.starrocks.qe.QueryState.MysqlStateType; import com.starrocks.qe.scheduler.Coordinator; +import com.starrocks.qe.scheduler.FeExecuteCoordinator; import com.starrocks.server.GlobalStateMgr; import com.starrocks.sql.ExplainAnalyzer; import com.starrocks.sql.StatementPlanner; @@ -165,7 +166,10 @@ import com.starrocks.sql.common.ErrorType; import com.starrocks.sql.common.MetaUtils; import com.starrocks.sql.common.StarRocksPlannerException; +import com.starrocks.sql.optimizer.OptExpression; import com.starrocks.sql.optimizer.dump.QueryDumpInfo; +import com.starrocks.sql.optimizer.operator.physical.PhysicalValuesOperator; +import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; import com.starrocks.sql.plan.ExecPlan; import com.starrocks.statistic.AnalyzeJob; import com.starrocks.statistic.AnalyzeMgr; @@ -197,6 +201,7 @@ import com.starrocks.transaction.TransactionState; import com.starrocks.transaction.TransactionStatus; import com.starrocks.transaction.VisibleStateWaiter; +import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang.exception.ExceptionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.LogManager; @@ -991,6 +996,7 @@ private void handleQueryStmt(ExecPlan execPlan) throws Exception { && StatementBase.ExplainLevel.ANALYZE.equals(parsedStmt.getExplainLevel()); boolean isSchedulerExplain = parsedStmt.isExplain() && StatementBase.ExplainLevel.SCHEDULER.equals(parsedStmt.getExplainLevel()); + boolean executeInFe = !isExplainAnalyze & !isSchedulerExplain & canExecuteInFe(context, execPlan.getPhysicalPlan()); if (isExplainAnalyze) { context.getSessionVariable().setEnableProfile(true); @@ -999,7 +1005,11 @@ private void handleQueryStmt(ExecPlan execPlan) throws Exception { } else if (isSchedulerExplain) { // Do nothing. } else if (parsedStmt.isExplain()) { - handleExplainStmt(buildExplainString(execPlan, ResourceGroupClassifier.QueryType.SELECT)); + String explainString = buildExplainString(execPlan, ResourceGroupClassifier.QueryType.SELECT); + if (executeInFe) { + explainString = "EXECUTE IN FE\n" + explainString; + } + handleExplainStmt(explainString); return; } if (context.getQueryDetail() != null) { @@ -1013,7 +1023,11 @@ private void handleQueryStmt(ExecPlan execPlan) throws Exception { List colNames = execPlan.getColNames(); List outputExprs = execPlan.getOutputExprs(); - coord = getCoordinatorFactory().createQueryScheduler(context, fragments, scanNodes, descTable); + if (executeInFe) { + coord = new FeExecuteCoordinator(context, execPlan); + } else { + coord = getCoordinatorFactory().createQueryScheduler(context, fragments, scanNodes, descTable); + } QeProcessorImpl.INSTANCE.registerQuery(context.getExecutionId(), new QeProcessorImpl.QueryInfo(context, originStmt.originStmt, coord)); @@ -2338,4 +2352,34 @@ public Pair, Status> executeStmtWithExecPlan(ConnectContext c public List getProxyResultBuffer() { return proxyResultBuffer; } + + + // scenes can execute in FE should meet all these requirements: + // 1. enable_constant_execute_in_fe = true + // 2. is mysql text protocol + // 3. all values are constantOperator + private boolean canExecuteInFe(ConnectContext context, OptExpression optExpression) { + if (!context.getSessionVariable().isEnableConstantExecuteInFE()) { + return false; + } + + if (context instanceof HttpConnectContext || context.getCommand() == MysqlCommand.COM_STMT_EXECUTE) { + return false; + } + + if (optExpression.getOp() instanceof PhysicalValuesOperator) { + PhysicalValuesOperator valuesOperator = (PhysicalValuesOperator) optExpression.getOp(); + boolean isAllConstants = true; + if (valuesOperator.getProjection() != null) { + isAllConstants = valuesOperator.getProjection().getColumnRefMap().values().stream() + .allMatch(ScalarOperator::isConstantRef); + } else if (CollectionUtils.isNotEmpty(valuesOperator.getRows())) { + isAllConstants = valuesOperator.getRows().stream().allMatch(row -> + row.stream().allMatch(ScalarOperator::isConstantRef)); + } + + return isAllConstants; + } + return false; + } } diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/scheduler/FeExecuteCoordinator.java b/fe/fe-core/src/main/java/com/starrocks/qe/scheduler/FeExecuteCoordinator.java new file mode 100644 index 00000000000000..a9a5e952941c42 --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/qe/scheduler/FeExecuteCoordinator.java @@ -0,0 +1,414 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.qe.scheduler; + +import com.google.common.collect.Lists; +import com.starrocks.analysis.Expr; +import com.starrocks.analysis.SlotRef; +import com.starrocks.catalog.ScalarType; +import com.starrocks.common.Status; +import com.starrocks.common.util.DateUtils; +import com.starrocks.common.util.RuntimeProfile; +import com.starrocks.datacache.DataCacheSelectMetrics; +import com.starrocks.mysql.MysqlSerializer; +import com.starrocks.planner.ScanNode; +import com.starrocks.proto.PPlanFragmentCancelReason; +import com.starrocks.proto.PQueryStatistics; +import com.starrocks.qe.ConnectContext; +import com.starrocks.qe.QueryStatisticsItem; +import com.starrocks.qe.RowBatch; +import com.starrocks.qe.scheduler.slot.LogicalSlot; +import com.starrocks.sql.common.RyuDouble; +import com.starrocks.sql.common.RyuFloat; +import com.starrocks.sql.optimizer.operator.physical.PhysicalValuesOperator; +import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; +import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator; +import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; +import com.starrocks.sql.plan.ExecPlan; +import com.starrocks.thrift.TLoadJobType; +import com.starrocks.thrift.TNetworkAddress; +import com.starrocks.thrift.TReportAuditStatisticsParams; +import com.starrocks.thrift.TReportExecStatusParams; +import com.starrocks.thrift.TResultBatch; +import com.starrocks.thrift.TSinkCommitInfo; +import com.starrocks.thrift.TTabletCommitInfo; +import com.starrocks.thrift.TTabletFailInfo; +import com.starrocks.thrift.TUniqueId; +import org.apache.commons.lang3.StringUtils; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.text.DecimalFormat; +import java.time.LocalDateTime; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Supplier; + +public class FeExecuteCoordinator extends Coordinator { + + private final ConnectContext connectContext; + + private final ExecPlan execPlan; + + + public FeExecuteCoordinator(ConnectContext context, ExecPlan execPlan) { + this.connectContext = context; + this.execPlan = execPlan; + } + @Override + public void startScheduling(boolean needDeploy) throws Exception { + + } + + @Override + public String getSchedulerExplain() { + return "FE EXECUTION"; + } + + @Override + public void updateFragmentExecStatus(TReportExecStatusParams params) { + + } + + @Override + public void updateAuditStatistics(TReportAuditStatisticsParams params) { + + } + + @Override + public void cancel(PPlanFragmentCancelReason reason, String message) { + + } + + @Override + public void onFinished() { + + } + + @Override + public LogicalSlot getSlot() { + return null; + } + + @Override + public RowBatch getNext() throws Exception { + RowBatch rowBatch = new RowBatch(); + TResultBatch resultBatch = new TResultBatch(); + resultBatch.rows = covertToMySQLRowBuffer(); + PQueryStatistics statistics = new PQueryStatistics(); + statistics.returnedRows = Long.valueOf(resultBatch.rows.size()); + rowBatch.setBatch(resultBatch); + rowBatch.setQueryStatistics(statistics); + return rowBatch; + } + + @Override + public boolean join(int timeoutSecond) { + return false; + } + + @Override + public boolean checkBackendState() { + return false; + } + + @Override + public boolean isThriftServerHighLoad() { + return false; + } + + @Override + public void setLoadJobType(TLoadJobType type) { + + } + + @Override + public long getLoadJobId() { + return 0; + } + + @Override + public void setLoadJobId(Long jobId) { + + } + + @Override + public Map getChannelIdToBEHTTPMap() { + return null; + } + + @Override + public Map getChannelIdToBEPortMap() { + return null; + } + + @Override + public boolean isEnableLoadProfile() { + return false; + } + + @Override + public void clearExportStatus() { + + } + + @Override + public void collectProfileSync() { + + } + + @Override + public boolean tryProcessProfileAsync(Consumer task) { + return false; + } + + @Override + public void setTopProfileSupplier(Supplier topProfileSupplier) { + + } + + @Override + public void setExecPlan(ExecPlan execPlan) { + + } + + @Override + public RuntimeProfile buildQueryProfile(boolean needMerge) { + return null; + } + + @Override + public RuntimeProfile getQueryProfile() { + return null; + } + + @Override + public List getDeltaUrls() { + return null; + } + + @Override + public Map getLoadCounters() { + return null; + } + + @Override + public List getFailInfos() { + return null; + } + + @Override + public List getCommitInfos() { + return null; + } + + @Override + public List getSinkCommitInfos() { + return null; + } + + @Override + public List getExportFiles() { + return null; + } + + @Override + public String getTrackingUrl() { + return null; + } + + @Override + public List getRejectedRecordPaths() { + return null; + } + + @Override + public List getFragmentInstanceInfos() { + return null; + } + + @Override + public DataCacheSelectMetrics getDataCacheSelectMetrics() { + return null; + } + + @Override + public PQueryStatistics getAuditStatistics() { + return null; + } + + @Override + public Status getExecStatus() { + return null; + } + + @Override + public boolean isUsingBackend(Long backendID) { + return false; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public TUniqueId getQueryId() { + return null; + } + + @Override + public void setQueryId(TUniqueId queryId) { + + } + + @Override + public List getScanNodes() { + return null; + } + + @Override + public long getStartTimeMs() { + return 0; + } + + @Override + public void setTimeoutSecond(int timeoutSecond) { + + } + + @Override + public boolean isProfileAlreadyReported() { + return false; + } + + private List covertToMySQLRowBuffer() { + MysqlSerializer serializer = MysqlSerializer.newInstance(); + PhysicalValuesOperator valuesOperator = (PhysicalValuesOperator) execPlan.getPhysicalPlan().getOp(); + List res = Lists.newArrayList(); + for (List row : valuesOperator.getRows()) { + serializer.reset(); + if (valuesOperator.getProjection() != null) { + List alignedOutput = Lists.newArrayList(); + for (Expr expr : execPlan.getOutputExprs()) { + SlotRef slotRef = (SlotRef) expr; + for (Map.Entry entry : valuesOperator.getProjection() + .getColumnRefMap().entrySet()) { + if (slotRef.getSlotId().asInt() == entry.getKey().getId()) { + alignedOutput.add(entry.getValue()); + break; + } + } + } + row = alignedOutput; + } + + for (ScalarOperator scalarOperator : row) { + ConstantOperator constantOperator = (ConstantOperator) scalarOperator; + if (constantOperator.isNull()) { + serializer.writeNull(); + } else if (constantOperator.isTrue()) { + serializer.writeLenEncodedString("1"); + } else if (constantOperator.isFalse()) { + serializer.writeLenEncodedString("0"); + } else if (constantOperator.getType().getPrimitiveType().isBinaryType()) { + serializer.writeVInt(constantOperator.getBinary().length); + serializer.writeBytes(constantOperator.getBinary()); + } else { + String value; + switch (constantOperator.getType().getPrimitiveType()) { + case TINYINT: + value = String.valueOf(constantOperator.getTinyInt()); + break; + case SMALLINT: + value = String.valueOf(constantOperator.getSmallint()); + break; + case INT: + value = String.valueOf(constantOperator.getInt()); + break; + case BIGINT: + value = String.valueOf(constantOperator.getBigint()); + break; + case LARGEINT: + value = String.valueOf(constantOperator.getLargeInt()); + break; + case FLOAT: + value = RyuFloat.floatToString((float) constantOperator.getFloat()); + break; + case DOUBLE: + value = RyuDouble.doubleToString(constantOperator.getDouble()); + break; + case DECIMALV2: + value = constantOperator.getDecimal().toPlainString(); + break; + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: + int scale = ((ScalarType) constantOperator.getType()).getScalarScale(); + BigDecimal val1 = constantOperator.getDecimal(); + DecimalFormat df = new DecimalFormat((scale == 0 ? "0" : "0.") + StringUtils.repeat("0", scale)); + value = df.format(val1); + break; + case CHAR: + value = constantOperator.getChar(); + break; + case VARCHAR: + value = constantOperator.getVarchar(); + break; + case TIME: + value = convertToTimeString(constantOperator.getTime()); + break; + case DATE: + LocalDateTime date = constantOperator.getDate(); + value = date.format(DateUtils.DATE_FORMATTER_UNIX); + break; + case DATETIME: + LocalDateTime datetime = constantOperator.getDate(); + if (datetime.getNano() != 0) { + value = datetime.format(DateUtils.DATE_TIME_MS_FORMATTER_UNIX); + } else { + value = datetime.format(DateUtils.DATE_TIME_FORMATTER_UNIX); + } + break; + default: + value = constantOperator.toString(); + } + serializer.writeLenEncodedString(value); + } + } + res.add(serializer.toByteBuffer()); + } + return res; + } + + private String convertToTimeString(double time) { + StringBuilder sb = new StringBuilder(); + if (time < 0) { + sb.append("-"); + time = Math.abs(time); + } + + int day = (int) (time / 86400); + time = time % 86400; + int hour = (int) (time / 3600); + time = time % 3600; + int minute = (int) (time / 60); + time = time % 60; + int second = (int) time; + sb.append(String.format("%02d:%02d:%02d", hour + day * 24, minute, second)); + return sb.toString(); + } +} diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/common/RoundingMode.java b/fe/fe-core/src/main/java/com/starrocks/sql/common/RoundingMode.java new file mode 100644 index 00000000000000..51541aae2ac130 --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/common/RoundingMode.java @@ -0,0 +1,46 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.common; + +/** + * Ref: https://github.com/ulfjack/ryu + */ +public enum RoundingMode { + CONSERVATIVE { + @Override + public boolean acceptUpperBound(boolean even) { + return false; + } + + @Override + public boolean acceptLowerBound(boolean even) { + return false; + } + }, + ROUND_EVEN { + @Override + public boolean acceptUpperBound(boolean even) { + return even; + } + + @Override + public boolean acceptLowerBound(boolean even) { + return even; + } + }; + + public abstract boolean acceptUpperBound(boolean even); + public abstract boolean acceptLowerBound(boolean even); +} \ No newline at end of file diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/common/RyuDouble.java b/fe/fe-core/src/main/java/com/starrocks/sql/common/RyuDouble.java new file mode 100644 index 00000000000000..6cf089601bb970 --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/common/RyuDouble.java @@ -0,0 +1,550 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.common; + +import java.math.BigInteger; + +/** + * Ref: https://github.com/ulfjack/ryu + * An implementation of Ryu for double. + */ +public final class RyuDouble { + private static boolean DEBUG = false; + + private static final int DOUBLE_MANTISSA_BITS = 52; + private static final long DOUBLE_MANTISSA_MASK = (1L << DOUBLE_MANTISSA_BITS) - 1; + + private static final int DOUBLE_EXPONENT_BITS = 11; + private static final int DOUBLE_EXPONENT_MASK = (1 << DOUBLE_EXPONENT_BITS) - 1; + private static final int DOUBLE_EXPONENT_BIAS = (1 << (DOUBLE_EXPONENT_BITS - 1)) - 1; + + private static final int POS_TABLE_SIZE = 326; + private static final int NEG_TABLE_SIZE = 291; + + // Only for debugging. + private static final BigInteger[] POW5 = new BigInteger[POS_TABLE_SIZE]; + private static final BigInteger[] POW5_INV = new BigInteger[NEG_TABLE_SIZE]; + + private static final int POW5_BITCOUNT = 121; // max 3*31 = 124 + private static final int POW5_QUARTER_BITCOUNT = 31; + private static final int[][] POW5_SPLIT = new int[POS_TABLE_SIZE][4]; + + private static final int POW5_INV_BITCOUNT = 122; // max 3*31 = 124 + private static final int POW5_INV_QUARTER_BITCOUNT = 31; + private static final int[][] POW5_INV_SPLIT = new int[NEG_TABLE_SIZE][4]; + + static { + BigInteger mask = BigInteger.valueOf(1).shiftLeft(POW5_QUARTER_BITCOUNT).subtract(BigInteger.ONE); + BigInteger invMask = BigInteger.valueOf(1).shiftLeft(POW5_INV_QUARTER_BITCOUNT).subtract(BigInteger.ONE); + for (int i = 0; i < Math.max(POW5.length, POW5_INV.length); i++) { + BigInteger pow = BigInteger.valueOf(5).pow(i); + int pow5len = pow.bitLength(); + int expectedPow5Bits = pow5bits(i); + if (expectedPow5Bits != pow5len) { + throw new IllegalStateException(pow5len + " != " + expectedPow5Bits); + } + if (i < POW5.length) { + POW5[i] = pow; + } + if (i < POW5_SPLIT.length) { + for (int j = 0; j < 4; j++) { + POW5_SPLIT[i][j] = pow + .shiftRight(pow5len - POW5_BITCOUNT + (3 - j) * POW5_QUARTER_BITCOUNT) + .and(mask) + .intValueExact(); + } + } + + if (i < POW5_INV_SPLIT.length) { + // We want floor(log_2 5^q) here, which is pow5len - 1. + int j = pow5len - 1 + POW5_INV_BITCOUNT; + BigInteger inv = BigInteger.ONE.shiftLeft(j).divide(pow).add(BigInteger.ONE); + POW5_INV[i] = inv; + for (int k = 0; k < 4; k++) { + if (k == 0) { + POW5_INV_SPLIT[i][k] = inv.shiftRight((3 - k) * POW5_INV_QUARTER_BITCOUNT).intValueExact(); + } else { + POW5_INV_SPLIT[i][k] = + inv.shiftRight((3 - k) * POW5_INV_QUARTER_BITCOUNT).and(invMask).intValueExact(); + } + } + } + } + } + + public static String doubleToString(double value) { + return doubleToString(value, RoundingMode.ROUND_EVEN); + } + + public static String doubleToString(double value, RoundingMode roundingMode) { + // Step 1: Decode the floating point number, and unify normalized and subnormal cases. + // First, handle all the trivial cases. + if (Double.isNaN(value)) { + return "NaN"; + } + if (value == Double.POSITIVE_INFINITY) { + return "Infinity"; + } + if (value == Double.NEGATIVE_INFINITY) { + return "-Infinity"; + } + long bits = Double.doubleToLongBits(value); + if (bits == 0) { + return "0.0"; + } + if (bits == 0x8000000000000000L) { + return "-0.0"; + } + + // Otherwise extract the mantissa and exponent bits and run the full algorithm. + int ieeeExponent = (int) ((bits >>> DOUBLE_MANTISSA_BITS) & DOUBLE_EXPONENT_MASK); + long ieeeMantissa = bits & DOUBLE_MANTISSA_MASK; + int e2; + long m2; + if (ieeeExponent == 0) { + // Denormal number - no implicit leading 1, and the exponent is 1, not 0. + e2 = 1 - DOUBLE_EXPONENT_BIAS - DOUBLE_MANTISSA_BITS; + m2 = ieeeMantissa; + } else { + // Add implicit leading 1. + e2 = ieeeExponent - DOUBLE_EXPONENT_BIAS - DOUBLE_MANTISSA_BITS; + m2 = ieeeMantissa | (1L << DOUBLE_MANTISSA_BITS); + } + + boolean sign = bits < 0; + if (DEBUG) { + System.out.println("IN=" + Long.toBinaryString(bits)); + System.out.println(" S=" + (sign ? "-" : "+") + " E=" + e2 + " M=" + m2); + } + + // Step 2: Determine the interval of legal decimal representations. + boolean even = (m2 & 1) == 0; + final long mv = 4 * m2; + final long mp = 4 * m2 + 2; + final int mmShift = ((m2 != (1L << DOUBLE_MANTISSA_BITS)) || (ieeeExponent <= 1)) ? 1 : 0; + final long mm = 4 * m2 - 1 - mmShift; + e2 -= 2; + + if (DEBUG) { + String sv; + String sp; + String sm; + int e10; + if (e2 >= 0) { + sv = BigInteger.valueOf(mv).shiftLeft(e2).toString(); + sp = BigInteger.valueOf(mp).shiftLeft(e2).toString(); + sm = BigInteger.valueOf(mm).shiftLeft(e2).toString(); + e10 = 0; + } else { + BigInteger factor = BigInteger.valueOf(5).pow(-e2); + sv = BigInteger.valueOf(mv).multiply(factor).toString(); + sp = BigInteger.valueOf(mp).multiply(factor).toString(); + sm = BigInteger.valueOf(mm).multiply(factor).toString(); + e10 = e2; + } + + e10 += sp.length() - 1; + + System.out.println("E =" + e10); + System.out.println("d+=" + sp); + System.out.println("d =" + sv); + System.out.println("d-=" + sm); + System.out.println("e2=" + e2); + } + + // Step 3: Convert to a decimal power base using 128-bit arithmetic. + // -1077 = 1 - 1023 - 53 - 2 <= e_2 - 2 <= 2046 - 1023 - 53 - 2 = 968 + long dv; + long dp; + long dm; + final int e10; + boolean dmIsTrailingZeros = false; + boolean dvIsTrailingZeros = false; + if (e2 >= 0) { + final int q = Math.max(0, ((e2 * 78913) >>> 18) - 1); + // k = constant + floor(log_2(5^q)) + final int k = POW5_INV_BITCOUNT + pow5bits(q) - 1; + final int i = -e2 + q + k; + dv = mulPow5InvDivPow2(mv, q, i); + dp = mulPow5InvDivPow2(mp, q, i); + dm = mulPow5InvDivPow2(mm, q, i); + e10 = q; + if (DEBUG) { + System.out.println(mv + " * 2^" + e2); + System.out.println("V+=" + dp); + System.out.println("V =" + dv); + System.out.println("V-=" + dm); + } + if (DEBUG) { + long exact = POW5_INV[q] + .multiply(BigInteger.valueOf(mv)) + .shiftRight(-e2 + q + k).longValueExact(); + System.out.println(exact + " " + POW5_INV[q].bitCount()); + if (dv != exact) { + throw new IllegalStateException(); + } + } + + if (q <= 21) { + if (mv % 5 == 0) { + dvIsTrailingZeros = multipleOfPowerOf5(mv, q); + } else if (roundingMode.acceptUpperBound(even)) { + dmIsTrailingZeros = multipleOfPowerOf5(mm, q); + } else if (multipleOfPowerOf5(mp, q)) { + dp--; + } + } + } else { + final int q = Math.max(0, ((-e2 * 732923) >>> 20) - 1); + final int i = -e2 - q; + final int k = pow5bits(i) - POW5_BITCOUNT; + final int j = q - k; + dv = mulPow5divPow2(mv, i, j); + dp = mulPow5divPow2(mp, i, j); + dm = mulPow5divPow2(mm, i, j); + e10 = q + e2; + if (DEBUG) { + System.out.println(mv + " * 5^" + (-e2) + " / 10^" + q); + } + if (q <= 1) { + dvIsTrailingZeros = true; + if (roundingMode.acceptUpperBound(even)) { + dmIsTrailingZeros = mmShift == 1; + } else { + dp--; + } + } else if (q < 63) { + dvIsTrailingZeros = (mv & ((1L << (q - 1)) - 1)) == 0; + } + } + if (DEBUG) { + System.out.println("d+=" + dp); + System.out.println("d =" + dv); + System.out.println("d-=" + dm); + System.out.println("e10=" + e10); + System.out.println("d-10=" + dmIsTrailingZeros); + System.out.println("d =" + dvIsTrailingZeros); + System.out.println("Accept upper=" + roundingMode.acceptUpperBound(even)); + System.out.println("Accept lower=" + roundingMode.acceptLowerBound(even)); + } + + // Step 4: Find the shortest decimal representation in the interval of legal representations. + // + // We do some extra work here in order to follow Float/Double.toString semantics. In particular, + // that requires printing in scientific format if and only if the exponent is between -3 and 7, + // and it requires printing at least two decimal digits. + // + // Above, we moved the decimal dot all the way to the right, so now we need to count digits to + // figure out the correct exponent for scientific notation. + final int vplength = decimalLength(dp); + int exp = e10 + vplength - 1; + + // Double.toString semantics requires using scientific notation if and only if outside this range. + boolean scientificNotation = !((exp >= -4) && (exp <= 15)); + + int removed = 0; + + int lastRemovedDigit = 0; + long output; + if (dmIsTrailingZeros || dvIsTrailingZeros) { + while (dp / 10 > dm / 10) { + if ((dp < 100) && scientificNotation) { + // Double.toString semantics requires printing at least two digits. + break; + } + dmIsTrailingZeros &= dm % 10 == 0; + dvIsTrailingZeros &= lastRemovedDigit == 0; + lastRemovedDigit = (int) (dv % 10); + dp /= 10; + dv /= 10; + dm /= 10; + removed++; + } + if (dmIsTrailingZeros && roundingMode.acceptLowerBound(even)) { + while (dm % 10 == 0) { + if ((dp < 100) && scientificNotation) { + // Double.toString semantics requires printing at least two digits. + break; + } + dvIsTrailingZeros &= lastRemovedDigit == 0; + lastRemovedDigit = (int) (dv % 10); + dp /= 10; + dv /= 10; + dm /= 10; + removed++; + } + } + if (dvIsTrailingZeros && (lastRemovedDigit == 5) && (dv % 2 == 0)) { + // Round even if the exact numbers is .....50..0. + lastRemovedDigit = 4; + } + output = dv + + ((dv == dm && !(dmIsTrailingZeros && roundingMode.acceptLowerBound(even))) || + (lastRemovedDigit >= 5) ? 1 : 0); + } else { + while (dp / 10 > dm / 10) { + if ((dp < 100) && scientificNotation) { + // Double.toString semantics requires printing at least two digits. + break; + } + lastRemovedDigit = (int) (dv % 10); + dp /= 10; + dv /= 10; + dm /= 10; + removed++; + } + output = dv + ((dv == dm || (lastRemovedDigit >= 5)) ? 1 : 0); + } + int olength = vplength - removed; + + if (DEBUG) { + System.out.println("LAST_REMOVED_DIGIT=" + lastRemovedDigit); + System.out.println("VP=" + dp); + System.out.println("VR=" + dv); + System.out.println("VM=" + dm); + System.out.println("O=" + output); + System.out.println("OLEN=" + olength); + System.out.println("EXP=" + exp); + } + + // Step 5: Print the decimal representation. + // We follow Double.toString semantics here. + char[] result = new char[24]; + int index = 0; + if (sign) { + result[index++] = '-'; + } + + // Values in the interval [1E-3, 1E7) are special. + if (scientificNotation) { + // Print in the format x.xxxxxE-yy. + for (int i = 0; i < olength - 1; i++) { + int c = (int) (output % 10); + output /= 10; + result[index + olength - i] = (char) ('0' + c); + } + result[index] = (char) ('0' + output % 10); + result[index + 1] = '.'; + index += olength + 1; + if (olength == 1) { + result[index++] = '0'; + } + + // Print 'E', the exponent sign, and the exponent, which has at most three digits. + result[index++] = 'e'; + if (exp < 0) { + result[index++] = '-'; + exp = -exp; + } + if (exp >= 100) { + result[index++] = (char) ('0' + exp / 100); + exp %= 100; + result[index++] = (char) ('0' + exp / 10); + } else if (exp >= 10) { + result[index++] = (char) ('0' + exp / 10); + } + result[index++] = (char) ('0' + exp % 10); + return new String(result, 0, index); + } else { + if (exp < 0) { + // Decimal dot is before any of the digits. + result[index++] = '0'; + result[index++] = '.'; + for (int i = -1; i > exp; i--) { + result[index++] = '0'; + } + int current = index; + for (int i = 0; i < olength; i++) { + result[current + olength - i - 1] = (char) ('0' + output % 10); + output /= 10; + index++; + } + } else if (exp + 1 >= olength) { + // Decimal dot is after any of the digits. + for (int i = 0; i < olength; i++) { + result[index + olength - i - 1] = (char) ('0' + output % 10); + output /= 10; + } + index += olength; + for (int i = olength; i < exp + 1; i++) { + result[index++] = '0'; + } + result[index++] = '.'; + result[index++] = '0'; + } else { + // Decimal dot is somewhere between the digits. + int current = index + 1; + for (int i = 0; i < olength; i++) { + if (olength - i - 1 == exp) { + result[current + olength - i - 1] = '.'; + current--; + } + result[current + olength - i - 1] = (char) ('0' + output % 10); + output /= 10; + } + index += olength + 1; + } + return new String(result, 0, index); + } + } + + private static int pow5bits(int e) { + return ((e * 1217359) >>> 19) + 1; + } + + private static int decimalLength(long v) { + if (v >= 1000000000000000000L) { + return 19; + } + if (v >= 100000000000000000L) { + return 18; + } + if (v >= 10000000000000000L) { + return 17; + } + if (v >= 1000000000000000L) { + return 16; + } + if (v >= 100000000000000L) { + return 15; + } + if (v >= 10000000000000L) { + return 14; + } + if (v >= 1000000000000L) { + return 13; + } + if (v >= 100000000000L) { + return 12; + } + if (v >= 10000000000L) { + return 11; + } + if (v >= 1000000000L) { + return 10; + } + if (v >= 100000000L) { + return 9; + } + if (v >= 10000000L) { + return 8; + } + if (v >= 1000000L) { + return 7; + } + if (v >= 100000L) { + return 6; + } + if (v >= 10000L) { + return 5; + } + if (v >= 1000L) { + return 4; + } + if (v >= 100L) { + return 3; + } + if (v >= 10L) { + return 2; + } + return 1; + } + + private static boolean multipleOfPowerOf5(long value, int q) { + return pow5Factor(value) >= q; + } + + private static int pow5Factor(long value) { + // We want to find the largest power of 5 that divides value. + if ((value % 5) != 0) { + return 0; + } + if ((value % 25) != 0) { + return 1; + } + if ((value % 125) != 0) { + return 2; + } + if ((value % 625) != 0) { + return 3; + } + int count = 4; + value /= 625; + while (value > 0) { + if (value % 5 != 0) { + return count; + } + value /= 5; + count++; + } + throw new IllegalArgumentException("" + value); + } + + /** + * Compute the high digits of m * 5^p / 10^q = m * 5^(p - q) / 2^q = m * 5^i / 2^j, with q chosen + * such that m * 5^i / 2^j has sufficiently many decimal digits to represent the original floating + * point number. + */ + private static long mulPow5divPow2(long m, int i, int j) { + // m has at most 55 bits. + long mHigh = m >>> 31; + long mLow = m & 0x7fffffff; + long bits13 = mHigh * POW5_SPLIT[i][0]; // 124 + long bits03 = mLow * POW5_SPLIT[i][0]; // 93 + long bits12 = mHigh * POW5_SPLIT[i][1]; // 93 + long bits02 = mLow * POW5_SPLIT[i][1]; // 62 + long bits11 = mHigh * POW5_SPLIT[i][2]; // 62 + long bits01 = mLow * POW5_SPLIT[i][2]; // 31 + long bits10 = mHigh * POW5_SPLIT[i][3]; // 31 + long bits00 = mLow * POW5_SPLIT[i][3]; // 0 + int actualShift = j - 3 * 31 - 21; + if (actualShift < 0) { + throw new IllegalArgumentException("" + actualShift); + } + return (((((( + ((bits00 >>> 31) + bits01 + bits10) >>> 31) + + bits02 + bits11) >>> 31) + + bits03 + bits12) >>> 21) + + (bits13 << 10)) >>> actualShift; + } + + /** + * Compute the high digits of m / 5^i / 2^j such that the result is accurate to at least 9 + * decimal digits. i and j are already chosen appropriately. + */ + private static long mulPow5InvDivPow2(long m, int i, int j) { + // m has at most 55 bits. + long mHigh = m >>> 31; + long mLow = m & 0x7fffffff; + long bits13 = mHigh * POW5_INV_SPLIT[i][0]; + long bits03 = mLow * POW5_INV_SPLIT[i][0]; + long bits12 = mHigh * POW5_INV_SPLIT[i][1]; + long bits02 = mLow * POW5_INV_SPLIT[i][1]; + long bits11 = mHigh * POW5_INV_SPLIT[i][2]; + long bits01 = mLow * POW5_INV_SPLIT[i][2]; + long bits10 = mHigh * POW5_INV_SPLIT[i][3]; + long bits00 = mLow * POW5_INV_SPLIT[i][3]; + + int actualShift = j - 3 * 31 - 21; + if (actualShift < 0) { + throw new IllegalArgumentException("" + actualShift); + } + return (((((( + ((bits00 >>> 31) + bits01 + bits10) >>> 31) + + bits02 + bits11) >>> 31) + + bits03 + bits12) >>> 21) + + (bits13 << 10)) >>> actualShift; + } +} \ No newline at end of file diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/common/RyuFloat.java b/fe/fe-core/src/main/java/com/starrocks/sql/common/RyuFloat.java new file mode 100644 index 00000000000000..71e6149f578c04 --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/common/RyuFloat.java @@ -0,0 +1,433 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.common; + +import java.math.BigInteger; + +/** + * Ref: https://github.com/ulfjack/ryu + * An implementation of Ryu for float. + */ +public final class RyuFloat { + private static boolean DEBUG = false; + + private static final int FLOAT_MANTISSA_BITS = 23; + private static final int FLOAT_MANTISSA_MASK = (1 << FLOAT_MANTISSA_BITS) - 1; + + private static final int FLOAT_EXPONENT_BITS = 8; + private static final int FLOAT_EXPONENT_MASK = (1 << FLOAT_EXPONENT_BITS) - 1; + private static final int FLOAT_EXPONENT_BIAS = (1 << (FLOAT_EXPONENT_BITS - 1)) - 1; + + private static final long LOG10_2_DENOMINATOR = 10000000L; + private static final long LOG10_2_NUMERATOR = (long) (LOG10_2_DENOMINATOR * Math.log10(2)); + + private static final long LOG10_5_DENOMINATOR = 10000000L; + private static final long LOG10_5_NUMERATOR = (long) (LOG10_5_DENOMINATOR * Math.log10(5)); + + private static final long LOG2_5_DENOMINATOR = 10000000L; + private static final long LOG2_5_NUMERATOR = (long) (LOG2_5_DENOMINATOR * (Math.log(5) / Math.log(2))); + + private static final int POS_TABLE_SIZE = 47; + private static final int INV_TABLE_SIZE = 31; + + // Only for debugging. + private static final BigInteger[] POW5 = new BigInteger[POS_TABLE_SIZE]; + private static final BigInteger[] POW5_INV = new BigInteger[INV_TABLE_SIZE]; + + private static final int POW5_BITCOUNT = 61; + private static final int POW5_HALF_BITCOUNT = 31; + private static final int[][] POW5_SPLIT = new int[POS_TABLE_SIZE][2]; + + private static final int POW5_INV_BITCOUNT = 59; + private static final int POW5_INV_HALF_BITCOUNT = 31; + private static final int[][] POW5_INV_SPLIT = new int[INV_TABLE_SIZE][2]; + + static { + BigInteger mask = BigInteger.valueOf(1).shiftLeft(POW5_HALF_BITCOUNT).subtract(BigInteger.ONE); + BigInteger maskInv = BigInteger.valueOf(1).shiftLeft(POW5_INV_HALF_BITCOUNT).subtract(BigInteger.ONE); + for (int i = 0; i < Math.max(POW5.length, POW5_INV.length); i++) { + BigInteger pow = BigInteger.valueOf(5).pow(i); + int pow5len = pow.bitLength(); + int expectedPow5Bits = pow5bits(i); + if (expectedPow5Bits != pow5len) { + throw new IllegalStateException(pow5len + " != " + expectedPow5Bits); + } + if (i < POW5.length) { + POW5[i] = pow; + } + if (i < POW5_SPLIT.length) { + POW5_SPLIT[i][0] = pow.shiftRight(pow5len - POW5_BITCOUNT + POW5_HALF_BITCOUNT).intValueExact(); + POW5_SPLIT[i][1] = pow.shiftRight(pow5len - POW5_BITCOUNT).and(mask).intValueExact(); + } + + if (i < POW5_INV.length) { + int j = pow5len - 1 + POW5_INV_BITCOUNT; + BigInteger inv = BigInteger.ONE.shiftLeft(j).divide(pow).add(BigInteger.ONE); + POW5_INV[i] = inv; + POW5_INV_SPLIT[i][0] = inv.shiftRight(POW5_INV_HALF_BITCOUNT).intValueExact(); + POW5_INV_SPLIT[i][1] = inv.and(maskInv).intValueExact(); + } + } + } + + public static String floatToString(float value) { + return floatToString(value, RoundingMode.ROUND_EVEN); + } + + public static String floatToString(float value, RoundingMode roundingMode) { + // Step 1: Decode the floating point number, and unify normalized and subnormal cases. + // First, handle all the trivial cases. + if (Float.isNaN(value)) { + return "NaN"; + } + if (value == Float.POSITIVE_INFINITY) { + return "Infinity"; + } + if (value == Float.NEGATIVE_INFINITY) { + return "-Infinity"; + } + int bits = Float.floatToIntBits(value); + if (bits == 0) { + return "0.0"; + } + if (bits == 0x80000000) { + return "-0.0"; + } + + // Otherwise extract the mantissa and exponent bits and run the full algorithm. + int ieeeExponent = (bits >> FLOAT_MANTISSA_BITS) & FLOAT_EXPONENT_MASK; + int ieeeMantissa = bits & FLOAT_MANTISSA_MASK; + // By default, the correct mantissa starts with a 1, except for denormal numbers. + int e2; + int m2; + if (ieeeExponent == 0) { + e2 = 1 - FLOAT_EXPONENT_BIAS - FLOAT_MANTISSA_BITS; + m2 = ieeeMantissa; + } else { + e2 = ieeeExponent - FLOAT_EXPONENT_BIAS - FLOAT_MANTISSA_BITS; + m2 = ieeeMantissa | (1 << FLOAT_MANTISSA_BITS); + } + + boolean sign = bits < 0; + if (DEBUG) { + System.out.println("IN=" + Long.toBinaryString(bits)); + System.out.println(" S=" + (sign ? "-" : "+") + " E=" + e2 + " M=" + m2); + } + + // Step 2: Determine the interval of legal decimal representations. + boolean even = (m2 & 1) == 0; + int mv = 4 * m2; + int mp = 4 * m2 + 2; + int mm = 4 * m2 - ((m2 != (1L << FLOAT_MANTISSA_BITS)) || (ieeeExponent <= 1) ? 2 : 1); + e2 -= 2; + + if (DEBUG) { + String sv; + String sp; + String sm; + int e10; + if (e2 >= 0) { + sv = BigInteger.valueOf(mv).shiftLeft(e2).toString(); + sp = BigInteger.valueOf(mp).shiftLeft(e2).toString(); + sm = BigInteger.valueOf(mm).shiftLeft(e2).toString(); + e10 = 0; + } else { + BigInteger factor = BigInteger.valueOf(5).pow(-e2); + sv = BigInteger.valueOf(mv).multiply(factor).toString(); + sp = BigInteger.valueOf(mp).multiply(factor).toString(); + sm = BigInteger.valueOf(mm).multiply(factor).toString(); + e10 = e2; + } + + e10 += sp.length() - 1; + + System.out.println("Exact values"); + System.out.println(" m =" + mv); + System.out.println(" E =" + e10); + System.out.println(" d+=" + sp); + System.out.println(" d =" + sv); + System.out.println(" d-=" + sm); + System.out.println(" e2=" + e2); + } + + // Step 3: Convert to a decimal power base using 128-bit arithmetic. + // -151 = 1 - 127 - 23 - 2 <= e_2 - 2 <= 254 - 127 - 23 - 2 = 102 + int dp; + int dv; + int dm; + int e10; + boolean dpIsTrailingZeros; + boolean dvIsTrailingZeros; + boolean dmIsTrailingZeros; + int lastRemovedDigit = 0; + if (e2 >= 0) { + // Compute m * 2^e_2 / 10^q = m * 2^(e_2 - q) / 5^q + int q = (int) (e2 * LOG10_2_NUMERATOR / LOG10_2_DENOMINATOR); + int k = POW5_INV_BITCOUNT + pow5bits(q) - 1; + int i = -e2 + q + k; + dv = (int) mulPow5InvDivPow2(mv, q, i); + dp = (int) mulPow5InvDivPow2(mp, q, i); + dm = (int) mulPow5InvDivPow2(mm, q, i); + if (q != 0 && ((dp - 1) / 10 <= dm / 10)) { + // We need to know one removed digit even if we are not going to loop below. We could use + // q = X - 1 above, except that would require 33 bits for the result, and we've found that + // 32-bit arithmetic is faster even on 64-bit machines. + int l = POW5_INV_BITCOUNT + pow5bits(q - 1) - 1; + lastRemovedDigit = (int) (mulPow5InvDivPow2(mv, q - 1, -e2 + q - 1 + l) % 10); + } + e10 = q; + if (DEBUG) { + System.out.println(mv + " * 2^" + e2 + " / 10^" + q); + } + + dpIsTrailingZeros = pow5Factor(mp) >= q; + dvIsTrailingZeros = pow5Factor(mv) >= q; + dmIsTrailingZeros = pow5Factor(mm) >= q; + } else { + // Compute m * 5^(-e_2) / 10^q = m * 5^(-e_2 - q) / 2^q + int q = (int) (-e2 * LOG10_5_NUMERATOR / LOG10_5_DENOMINATOR); + int i = -e2 - q; + int k = pow5bits(i) - POW5_BITCOUNT; + int j = q - k; + dv = (int) mulPow5divPow2(mv, i, j); + dp = (int) mulPow5divPow2(mp, i, j); + dm = (int) mulPow5divPow2(mm, i, j); + if (q != 0 && ((dp - 1) / 10 <= dm / 10)) { + j = q - 1 - (pow5bits(i + 1) - POW5_BITCOUNT); + lastRemovedDigit = (int) (mulPow5divPow2(mv, i + 1, j) % 10); + } + e10 = q + e2; // Note: e2 and e10 are both negative here. + if (DEBUG) { + System.out.println( + mv + " * 5^" + (-e2) + " / 10^" + q + " = " + mv + " * 5^" + (-e2 - q) + " / 2^" + q); + } + + dpIsTrailingZeros = 1 >= q; + dvIsTrailingZeros = (q < FLOAT_MANTISSA_BITS) && (mv & ((1 << (q - 1)) - 1)) == 0; + dmIsTrailingZeros = (mm % 2 == 1 ? 0 : 1) >= q; + } + if (DEBUG) { + System.out.println("Actual values"); + System.out.println(" d+=" + dp); + System.out.println(" d =" + dv); + System.out.println(" d-=" + dm); + System.out.println(" last removed=" + lastRemovedDigit); + System.out.println(" e10=" + e10); + System.out.println(" d+10=" + dpIsTrailingZeros); + System.out.println(" d =" + dvIsTrailingZeros); + System.out.println(" d-10=" + dmIsTrailingZeros); + } + + // Step 4: Find the shortest decimal representation in the interval of legal representations. + // + // We do some extra work here in order to follow Float/Double.toString semantics. In particular, + // that requires printing in scientific format if and only if the exponent is between -3 and 7, + // and it requires printing at least two decimal digits. + // + // Above, we moved the decimal dot all the way to the right, so now we need to count digits to + // figure out the correct exponent for scientific notation. + int dplength = decimalLength(dp); + int exp = e10 + dplength - 1; + + // Float.toString semantics requires using scientific notation if and only if outside this range. + boolean scientificNotation = !((exp >= -4) && (exp <= 7)); + + int removed = 0; + if (dpIsTrailingZeros && !roundingMode.acceptUpperBound(even)) { + dp--; + } + + while (dp / 10 > dm / 10) { + if ((dp < 100) && scientificNotation) { + // We print at least two digits, so we might as well stop now. + break; + } + dmIsTrailingZeros &= dm % 10 == 0; + dp /= 10; + lastRemovedDigit = dv % 10; + dv /= 10; + dm /= 10; + removed++; + } + if (dmIsTrailingZeros && roundingMode.acceptLowerBound(even)) { + while (dm % 10 == 0) { + if ((dp < 100) && scientificNotation) { + // We print at least two digits, so we might as well stop now. + break; + } + dp /= 10; + lastRemovedDigit = dv % 10; + dv /= 10; + dm /= 10; + removed++; + } + } + + if (dvIsTrailingZeros && (lastRemovedDigit == 5) && (dv % 2 == 0)) { + // Round down not up if the number ends in X50000 and the number is even. + lastRemovedDigit = 4; + } + int output = dv + + ((dv == dm && !(dmIsTrailingZeros && roundingMode.acceptLowerBound(even))) || (lastRemovedDigit >= 5) ? + 1 : 0); + int olength = dplength - removed; + + if (DEBUG) { + System.out.println("Actual values after loop"); + System.out.println(" d+=" + dp); + System.out.println(" d =" + dv); + System.out.println(" d-=" + dm); + System.out.println(" last removed=" + lastRemovedDigit); + System.out.println(" e10=" + e10); + System.out.println(" d+10=" + dpIsTrailingZeros); + System.out.println(" d-10=" + dmIsTrailingZeros); + System.out.println(" output=" + output); + System.out.println(" output_length=" + olength); + System.out.println(" output_exponent=" + exp); + } + + // Step 5: Print the decimal representation. + // We follow Float.toString semantics here. + char[] result = new char[15]; + int index = 0; + if (sign) { + result[index++] = '-'; + } + + if (scientificNotation) { + // Print in the format x.xxxxxE-yy. + for (int i = 0; i < olength - 1; i++) { + int c = output % 10; + output /= 10; + result[index + olength - i] = (char) ('0' + c); + } + result[index] = (char) ('0' + output % 10); + result[index + 1] = '.'; + index += olength + 1; + if (olength == 1) { + result[index++] = '0'; + } + + // Print 'e', the exponent sign, and the exponent, which has at most two digits. + result[index++] = 'e'; + if (exp < 0) { + result[index++] = '-'; + exp = -exp; + } + if (exp >= 10) { + result[index++] = (char) ('0' + exp / 10); + } + result[index++] = (char) ('0' + exp % 10); + } else { + if (exp < 0) { + // Decimal dot is before any of the digits. + result[index++] = '0'; + result[index++] = '.'; + for (int i = -1; i > exp; i--) { + result[index++] = '0'; + } + int current = index; + for (int i = 0; i < olength; i++) { + result[current + olength - i - 1] = (char) ('0' + output % 10); + output /= 10; + index++; + } + } else if (exp + 1 >= olength) { + // Decimal dot is after any of the digits. + for (int i = 0; i < olength; i++) { + result[index + olength - i - 1] = (char) ('0' + output % 10); + output /= 10; + } + index += olength; + for (int i = olength; i < exp + 1; i++) { + result[index++] = '0'; + } + result[index++] = '.'; + result[index++] = '0'; + } else { + // Decimal dot is somewhere between the digits. + int current = index + 1; + for (int i = 0; i < olength; i++) { + if (olength - i - 1 == exp) { + result[current + olength - i - 1] = '.'; + current--; + } + result[current + olength - i - 1] = (char) ('0' + output % 10); + output /= 10; + } + index += olength + 1; + } + } + return new String(result, 0, index); + } + + private static int pow5bits(int e) { + return e == 0 ? 1 : (int) ((e * LOG2_5_NUMERATOR + LOG2_5_DENOMINATOR - 1) / LOG2_5_DENOMINATOR); + } + + /** + * Returns the exponent of the largest power of 5 that divides the given value, i.e., returns + * i such that value = 5^i * x, where x is an integer. + */ + private static int pow5Factor(int value) { + int count = 0; + while (value > 0) { + if (value % 5 != 0) { + return count; + } + value /= 5; + count++; + } + throw new IllegalArgumentException("" + value); + } + + /** + * Compute the exact result of [m * 5^(-e_2) / 10^q] = [m * 5^(-e_2 - q) / 2^q] + * = [m * [5^(p - q)/2^k] / 2^(q - k)] = [m * POW5[i] / 2^j]. + */ + private static long mulPow5divPow2(int m, int i, int j) { + if (j - POW5_HALF_BITCOUNT < 0) { + throw new IllegalArgumentException(); + } + long bits0 = m * (long) POW5_SPLIT[i][0]; + long bits1 = m * (long) POW5_SPLIT[i][1]; + return (bits0 + (bits1 >> POW5_HALF_BITCOUNT)) >> (j - POW5_HALF_BITCOUNT); + } + + /** + * Compute the exact result of [m * 2^p / 10^q] = [m * 2^(p - q) / 5 ^ q] + * = [m * [2^k / 5^q] / 2^-(p - q - k)] = [m * POW5_INV[q] / 2^j]. + */ + private static long mulPow5InvDivPow2(int m, int q, int j) { + if (j - POW5_INV_HALF_BITCOUNT < 0) { + throw new IllegalArgumentException(); + } + long bits0 = m * (long) POW5_INV_SPLIT[q][0]; + long bits1 = m * (long) POW5_INV_SPLIT[q][1]; + return (bits0 + (bits1 >> POW5_INV_HALF_BITCOUNT)) >> (j - POW5_INV_HALF_BITCOUNT); + } + + private static int decimalLength(int v) { + int length = 10; + int factor = 1000000000; + for (; length > 0; length--) { + if (v >= factor) { + break; + } + factor /= 10; + } + return length; + } +} diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/SelectConstTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/SelectConstTest.java index e4213ac18714e0..1a0ee56a7c2f51 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/SelectConstTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/SelectConstTest.java @@ -14,8 +14,13 @@ package com.starrocks.sql.plan; +import com.starrocks.qe.RowBatch; +import com.starrocks.qe.scheduler.FeExecuteCoordinator; +import org.junit.Assert; import org.junit.Test; +import java.nio.charset.StandardCharsets; + public class SelectConstTest extends PlanTestBase { @Test public void testSelectConst() throws Exception { @@ -151,4 +156,91 @@ public void testDoubleCastWithoutScientificNotation() throws Exception { "WHEN false THEN 1 ELSE 2 / 3 END AS STRING ) AS BOOLEAN );"; assertPlanContains(sql, "PREDICATES: CAST('-1229625855' AS BOOLEAN)"); } + + @Test + public void testSystemVariable() throws Exception { + String sql = "SELECT @@session.auto_increment_increment, @@character_set_client, @@character_set_connection, " + + "@@character_set_results, @@character_set_server, @@collation_server, @@collation_connection, " + + "@@init_connect, @@interactive_timeout, @@language, @@license, @@lower_case_table_names, @@max_allowed_packet, " + + "@@net_write_timeout, @@performance_schema, @@query_cache_size, @@query_cache_type, @@sql_mode, " + + "@@system_time_zone, @@time_zone, @@tx_isolation, @@wait_timeout"; + String plan = getFragmentPlan(sql); + assertPlanContains(sql, "1:Project\n" + + " | : 1\n" + + " | : 'utf8'\n" + + " | : 'utf8'\n" + + " | : 'utf8'\n" + + " | : 'utf8'\n" + + " | : 'utf8_general_ci'\n" + + " | : 'utf8_general_ci'\n" + + " | : ''\n" + + " | : 3600\n" + + " | : '/starrocks/share/english/'\n" + + " | : 'Apache License 2.0'\n" + + " | : 0\n" + + " | : 33554432\n" + + " | : 60\n" + + " | : FALSE\n" + + " | : 1048576\n" + + " | : 0\n" + + " | : 'ONLY_FULL_GROUP_BY'\n" + + " | : 'Asia/Shanghai'\n" + + " | : 'Asia/Shanghai'\n" + + " | : 'REPEATABLE-READ'\n" + + " | : 28800"); + } + + @Test + public void testExecuteInFe() throws Exception { + assertFeExecuteResult("select -1", "-1"); + assertFeExecuteResult("select -123456.789", "-123456.789"); + assertFeExecuteResult("select 100000000000000", "100000000000000"); + assertFeExecuteResult("select cast(0.00001 as float)", "1.0e-5"); + assertFeExecuteResult("select cast(0.00000000000001 as double)", "1.0e-14"); + assertFeExecuteResult("select '2021-01-01'", "2021-01-01"); + assertFeExecuteResult("select '2021-01-01 01:01:01.1234'", "2021-01-01 01:01:01.1234"); + assertFeExecuteResult("select cast(1.23456000 as decimalv2)", "1.23456"); + assertFeExecuteResult("select cast(1.23456000 as DECIMAL(10, 2))", "1.23"); + assertFeExecuteResult("select cast(1.234560 as DECIMAL(12, 10))", "1.2345600000"); + assertFeExecuteResult("select '\\'abc'", "'abc"); + assertFeExecuteResult("select '\"abc'", "\"abc"); + assertFeExecuteResult("select '\\\\\\'abc'", "\\'abc"); + assertFeExecuteResult("select timediff('1000-01-02 01:01:01.123456', '1000-01-01 01:01:01.000001')", + "24:00:00"); + assertFeExecuteResult("select timediff('9999-01-02 01:01:01.123456', '1000-01-01 01:01:01.000001')", + "78883632:00:00"); + assertFeExecuteResult("select timediff('1000-01-01 01:01:01.000001', '9999-01-02 01:01:01.123456')", + "-78883632:00:01"); + } + + private void assertFeExecuteResult(String sql, String expected) throws Exception { + ExecPlan execPlan = getExecPlan(sql); + FeExecuteCoordinator coordinator = new FeExecuteCoordinator(connectContext, execPlan); + RowBatch rowBatch = coordinator.getNext(); + byte[] bytes = rowBatch.getBatch().getRows().get(0).array(); + int lengthOffset = getOffset(bytes); + String value; + if (lengthOffset == -1) { + value = "NULL"; + } else { + value = new String(bytes, lengthOffset, bytes.length - lengthOffset, StandardCharsets.UTF_8); + } + Assert.assertEquals(expected, value); + } + + private static int getOffset(byte[] bytes) { + int sw = bytes[0] & 0xff; + switch (sw) { + case 251: + return -1; + case 252: + return 3; + case 253: + return 4; + case 254: + return 9; + default: + return 1; + } + } } diff --git a/test/lib/sr_sql_lib.py b/test/lib/sr_sql_lib.py index 0b8eb2560fee01..29e145077f3a5f 100644 --- a/test/lib/sr_sql_lib.py +++ b/test/lib/sr_sql_lib.py @@ -1677,3 +1677,20 @@ def set_first_tablet_bad_and_recover(self, table_name): time.sleep(0.5) else: break + def assert_explain_contains(self, query, *expects): + """ + assert explain result contains expect string + """ + sql = "explain %s" % (query) + res = self.execute_sql(sql, True) + for expect in expects: + tools.assert_true(str(res["result"]).find(expect) > 0, "assert expect %s is not found in plan" % (expect)) + + def assert_explain_not_contains(self, query, *expects): + """ + assert explain result contains expect string + """ + sql = "explain %s" % (query) + res = self.execute_sql(sql, True) + for expect in expects: + tools.assert_true(str(res["result"]).find(expect) == -1, "assert expect %s is found in plan" % (expect)) diff --git a/test/sql/test_execute_in_fe/R/test_execute_in_fe b/test/sql/test_execute_in_fe/R/test_execute_in_fe new file mode 100644 index 00000000000000..871875fb10dd36 --- /dev/null +++ b/test/sql/test_execute_in_fe/R/test_execute_in_fe @@ -0,0 +1,89 @@ +-- name: test_execute_in_fe +set enable_constant_execute_in_fe = true; +-- result: +-- !result +select 1, -1, 1.23456, cast(1.123 as float), cast(1.123 as double), cast(10 as bigint), cast(100 as largeint), +1000000000000, 1+1, 100 * 100, 'abc', "中文", '"abc"', "'abc'", '\'abc\\', "\"abc\\", cast(1.123000000 as decimalv2), +cast(1.123 as decimal(10, 7)), date '2021-01-01', datetime '2021-01-01 00:00:00', datetime '2021-01-01 00:00:00.123456', +timediff('2028-01-01 11:25:36', '2000-11-21 12:12:12'), timediff('2000-11-21 12:12:12', '2028-01-01 11:25:36'), x'123456', x'AABBCC11'; +-- result: +1 -1 1.23456 1.123 1.123 10 100 1000000000000 2 10000 abc 中文 "abc" 'abc' 'abc\ "abc\ 1.123 1.1230000 2021-01-01 2021-01-01 00:00:00 2021-01-01 00:00:00.123456 237647:13:24 -237647:13:24 4V b'\xaa\xbb\xcc\x11' +-- !result +select 1 as a union all select 2 union all select 1000000000; +-- result: +1 +2 +1000000000 +-- !result +select @@character_set_results AS character_set_results; +-- result: +utf8 +-- !result +select cast(10000000 as float), cast(1000000000000000 as double); +-- result: +10000000.0 1000000000000000.0 +-- !result +select cast(0.00001 as float), cast(0.00001 as double); +-- result: +1e-05 1e-05 +-- !result +function: assert_explain_contains("SELECT @@session.auto_increment_increment, @@character_set_client, @@character_set_connection, @@character_set_results, @@character_set_server, @@collation_server, @@collation_connection, @@init_connect, @@interactive_timeout, @@language, @@license, @@lower_case_table_names, @@max_allowed_packet, @@net_write_timeout, @@performance_schema, @@query_cache_size, @@query_cache_type, @@sql_mode, @@system_time_zone, @@time_zone, @@tx_isolation, @@wait_timeout", "EXECUTE IN FE") +-- result: +None +-- !result +SELECT @@session.auto_increment_increment, @@character_set_client, @@character_set_connection, @@character_set_results, @@character_set_server, @@collation_server, @@collation_connection, @@init_connect, @@interactive_timeout, @@language, @@license, @@lower_case_table_names, @@max_allowed_packet, @@net_write_timeout, @@performance_schema, @@query_cache_size, @@query_cache_type, @@system_time_zone, @@time_zone, @@tx_isolation, @@wait_timeout; +-- result: +1 utf8 utf8 utf8 utf8 utf8_general_ci utf8_general_ci 3600 /starrocks/share/english/ Apache License 2.0 0 33554432 60 0 1048576 0 Asia/Shanghai Asia/Shanghai REPEATABLE-READ 28800 +-- !result +select cast(100 as time); +-- result: +0:01:00 +-- !result +select cast(1.123 as time); +-- result: +0:00:01 +-- !result +set enable_constant_execute_in_fe = false; +-- result: +-- !result +select 1, -1, 1.23456, cast(1.123 as float), cast(1.123 as double), cast(10 as bigint), cast(100 as largeint), +1000000000000, 1+1, 100 * 100, 'abc', "中文", '"abc"', "'abc'", '\'abc\\', "\"abc\\", cast(1.123000000 as decimalv2), +cast(1.123 as decimal(10, 7)), date '2021-01-01', datetime '2021-01-01 00:00:00', datetime '2021-01-01 00:00:00.123456', +timediff('2028-01-01 11:25:36', '2000-11-21 12:12:12'), timediff('2000-11-21 12:12:12', '2028-01-01 11:25:36'), x'123456', x'AABBCC11'; +-- result: +1 -1 1.23456 1.123 1.123 10 100 1000000000000 2 10000 abc 中文 "abc" 'abc' 'abc\ "abc\ 1.123 1.1230000 2021-01-01 2021-01-01 00:00:00 2021-01-01 00:00:00.123456 237647:13:24 -237647:13:24 4V b'\xaa\xbb\xcc\x11' +-- !result +function: assert_explain_not_contains("SELECT @@session.auto_increment_increment, @@character_set_client, @@character_set_connection, @@character_set_results, @@character_set_server, @@collation_server, @@collation_connection, @@init_connect, @@interactive_timeout, @@language, @@license, @@lower_case_table_names, @@max_allowed_packet, @@net_write_timeout, @@performance_schema, @@query_cache_size, @@query_cache_type, @@sql_mode, @@system_time_zone, @@time_zone, @@tx_isolation, @@wait_timeout", "EXECUTE IN FE") +-- result: +None +-- !result +SELECT @@session.auto_increment_increment, @@character_set_client, @@character_set_connection, @@character_set_results, @@character_set_server, @@collation_server, @@collation_connection, @@init_connect, @@interactive_timeout, @@language, @@license, @@lower_case_table_names, @@max_allowed_packet, @@net_write_timeout, @@performance_schema, @@query_cache_size, @@query_cache_type, @@system_time_zone, @@time_zone, @@tx_isolation, @@wait_timeout; +-- result: +1 utf8 utf8 utf8 utf8 utf8_general_ci utf8_general_ci 3600 /starrocks/share/english/ Apache License 2.0 0 33554432 60 0 1048576 0 Asia/Shanghai Asia/Shanghai REPEATABLE-READ 28800 +-- !result +select 1 as a union all select 2 union all select 1000000000; +-- result: +1 +2 +1000000000 +-- !result +select @@character_set_results AS character_set_results; +-- result: +utf8 +-- !result +select cast(10000000 as float), cast(1000000000000000 as double); +-- result: +10000000.0 1000000000000000.0 +-- !result +select cast(0.00001 as float), cast(0.00001 as double); +-- result: +1e-05 1e-05 +-- !result +select cast(100 as time); +-- result: +0:01:00 +-- !result +select cast(1.123 as time); +-- result: +0:00:01 +-- !result \ No newline at end of file diff --git a/test/sql/test_execute_in_fe/T/test_execute_in_fe b/test/sql/test_execute_in_fe/T/test_execute_in_fe new file mode 100644 index 00000000000000..23172ddd5972c3 --- /dev/null +++ b/test/sql/test_execute_in_fe/T/test_execute_in_fe @@ -0,0 +1,35 @@ +-- name: test_execute_in_fe +-- execute in fe +set enable_constant_execute_in_fe = true; +select 1, -1, 1.23456, cast(1.123 as float), cast(1.123 as double), cast(10 as bigint), cast(100 as largeint), +1000000000000, 1+1, 100 * 100, 'abc', "中文", '"abc"', "'abc'", '\'abc\\', "\"abc\\", cast(1.123000000 as decimalv2), +cast(1.123 as decimal(10, 7)), date '2021-01-01', datetime '2021-01-01 00:00:00', datetime '2021-01-01 00:00:00.123456', +timediff('2028-01-01 11:25:36', '2000-11-21 12:12:12'), timediff('2000-11-21 12:12:12', '2028-01-01 11:25:36'), x'123456', x'AABBCC11'; + +select 1 as a union all select 2 union all select 1000000000; +select @@character_set_results AS character_set_results; +select cast(10000000 as float), cast(1000000000000000 as double); +select cast(0.00001 as float), cast(0.00001 as double); + +function: assert_explain_contains("SELECT @@session.auto_increment_increment, @@character_set_client, @@character_set_connection, @@character_set_results, @@character_set_server, @@collation_server, @@collation_connection, @@init_connect, @@interactive_timeout, @@language, @@license, @@lower_case_table_names, @@max_allowed_packet, @@net_write_timeout, @@performance_schema, @@query_cache_size, @@query_cache_type, @@sql_mode, @@system_time_zone, @@time_zone, @@tx_isolation, @@wait_timeout", "EXECUTE IN FE") +SELECT @@session.auto_increment_increment, @@character_set_client, @@character_set_connection, @@character_set_results, @@character_set_server, @@collation_server, @@collation_connection, @@init_connect, @@interactive_timeout, @@language, @@license, @@lower_case_table_names, @@max_allowed_packet, @@net_write_timeout, @@performance_schema, @@query_cache_size, @@query_cache_type, @@system_time_zone, @@time_zone, @@tx_isolation, @@wait_timeout; +-- cast cannot be folding in fe +select cast(100 as time); +select cast(1.123 as time); + +-- execute in be +set enable_constant_execute_in_fe = false; +select 1, -1, 1.23456, cast(1.123 as float), cast(1.123 as double), cast(10 as bigint), cast(100 as largeint), +1000000000000, 1+1, 100 * 100, 'abc', "中文", '"abc"', "'abc'", '\'abc\\', "\"abc\\", cast(1.123000000 as decimalv2), +cast(1.123 as decimal(10, 7)), date '2021-01-01', datetime '2021-01-01 00:00:00', datetime '2021-01-01 00:00:00.123456', +timediff('2028-01-01 11:25:36', '2000-11-21 12:12:12'), timediff('2000-11-21 12:12:12', '2028-01-01 11:25:36'), x'123456', x'AABBCC11'; +function: assert_explain_not_contains("SELECT @@session.auto_increment_increment, @@character_set_client, @@character_set_connection, @@character_set_results, @@character_set_server, @@collation_server, @@collation_connection, @@init_connect, @@interactive_timeout, @@language, @@license, @@lower_case_table_names, @@max_allowed_packet, @@net_write_timeout, @@performance_schema, @@query_cache_size, @@query_cache_type, @@sql_mode, @@system_time_zone, @@time_zone, @@tx_isolation, @@wait_timeout", "EXECUTE IN FE") +SELECT @@session.auto_increment_increment, @@character_set_client, @@character_set_connection, @@character_set_results, @@character_set_server, @@collation_server, @@collation_connection, @@init_connect, @@interactive_timeout, @@language, @@license, @@lower_case_table_names, @@max_allowed_packet, @@net_write_timeout, @@performance_schema, @@query_cache_size, @@query_cache_type, @@system_time_zone, @@time_zone, @@tx_isolation, @@wait_timeout; + +select 1 as a union all select 2 union all select 1000000000; +select @@character_set_results AS character_set_results; +select cast(10000000 as float), cast(1000000000000000 as double); +select cast(0.00001 as float), cast(0.00001 as double); + +select cast(100 as time); +select cast(1.123 as time); \ No newline at end of file diff --git a/test/sql/test_others/R/test_deprecated_non_pipeline_engine b/test/sql/test_others/R/test_deprecated_non_pipeline_engine index 8d611db0f8a35b..597bb8f369620a 100644 --- a/test/sql/test_others/R/test_deprecated_non_pipeline_engine +++ b/test/sql/test_others/R/test_deprecated_non_pipeline_engine @@ -2,6 +2,9 @@ set enable_pipeline_engine=false; -- result: -- !result +set enable_constant_execute_in_fe = false; +-- result: +-- !result select 1; -- result: [REGEX].*non-pipeline engine is no longer supported since 3.2, please set enable_pipeline_engine=true.* diff --git a/test/sql/test_others/T/test_deprecated_non_pipeline_engine b/test/sql/test_others/T/test_deprecated_non_pipeline_engine index 3015148cbfdb60..8ec666fcba8ac9 100644 --- a/test/sql/test_others/T/test_deprecated_non_pipeline_engine +++ b/test/sql/test_others/T/test_deprecated_non_pipeline_engine @@ -1,5 +1,6 @@ -- name: test_deprecated_non_pipeline_engine @sequential set enable_pipeline_engine=false; +set enable_constant_execute_in_fe = false; select 1; update information_schema.be_configs set value=60 where name = 'base_compaction_check_interval_seconds'; select `value` from information_schema.be_configs where name = 'base_compaction_check_interval_seconds' limit 1;