Skip to content

Commit e47f1e6

Browse files
Praveen2112wendigo
authored andcommitted
Allow complex expression rewrites with double and real constant
1 parent 425fb1f commit e47f1e6

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,18 @@
2727
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.constant;
2828
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type;
2929
import static io.trino.spi.type.BigintType.BIGINT;
30+
import static io.trino.spi.type.DoubleType.DOUBLE;
3031
import static io.trino.spi.type.IntegerType.INTEGER;
32+
import static io.trino.spi.type.RealType.REAL;
3133
import static io.trino.spi.type.SmallintType.SMALLINT;
3234
import static io.trino.spi.type.TinyintType.TINYINT;
3335

3436
public class RewriteExactNumericConstant
3537
implements ConnectorExpressionRule<Constant, ParameterizedExpression>
3638
{
3739
private static final Pattern<Constant> PATTERN = constant().with(type().matching(type ->
38-
type == TINYINT || type == SMALLINT || type == INTEGER || type == BIGINT || type instanceof DecimalType));
40+
type == TINYINT || type == SMALLINT || type == INTEGER || type == BIGINT
41+
|| type == REAL || type == DOUBLE || type instanceof DecimalType));
3942

4043
@Override
4144
public Pattern<Constant> getPattern()
@@ -52,7 +55,7 @@ public Optional<ParameterizedExpression> rewrite(Constant constant, Captures cap
5255
// TODO we could handle NULL values too
5356
return Optional.empty();
5457
}
55-
if (type == TINYINT || type == SMALLINT || type == INTEGER || type == BIGINT || type instanceof DecimalType) {
58+
if (type == TINYINT || type == SMALLINT || type == INTEGER || type == BIGINT || type == REAL || type == DOUBLE || type instanceof DecimalType) {
5659
return Optional.of(new ParameterizedExpression("?", ImmutableList.of(new QueryParameter(type, Optional.of(value)))));
5760
}
5861

plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,11 @@
6868
import static io.trino.spi.type.BigintType.BIGINT;
6969
import static io.trino.spi.type.BooleanType.BOOLEAN;
7070
import static io.trino.spi.type.DoubleType.DOUBLE;
71+
import static io.trino.spi.type.RealType.REAL;
7172
import static io.trino.spi.type.VarcharType.VARCHAR;
7273
import static io.trino.spi.type.VarcharType.createVarcharType;
7374
import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
75+
import static java.lang.Float.floatToIntBits;
7476
import static java.lang.String.format;
7577
import static org.assertj.core.api.Assertions.assertThat;
7678

@@ -93,6 +95,13 @@ public class TestPostgreSqlClient
9395
.setJdbcTypeHandle(new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))
9496
.build();
9597

98+
private static final JdbcColumnHandle REAL_COLUMN =
99+
JdbcColumnHandle.builder()
100+
.setColumnName("c_real")
101+
.setColumnType(REAL)
102+
.setJdbcTypeHandle(new JdbcTypeHandle(Types.REAL, Optional.of("real"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))
103+
.build();
104+
96105
private static final JdbcColumnHandle VARCHAR_COLUMN =
97106
JdbcColumnHandle.builder()
98107
.setColumnName("c_varchar")
@@ -235,15 +244,18 @@ public void testConvertOr()
235244
Logical.Operator.OR,
236245
List.of(
237246
new Comparison(Comparison.Operator.EQUAL, new Reference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L)),
238-
new Comparison(Comparison.Operator.EQUAL, new Reference(BIGINT, "c_bigint_symbol_2"), new Constant(BIGINT, 415L))))),
247+
new Comparison(Comparison.Operator.EQUAL, new Reference(REAL, "c_real_symbol"), new Constant(REAL, (long) floatToIntBits(3.14f))),
248+
new Comparison(Comparison.Operator.EQUAL, new Reference(DOUBLE, "c_double_symbol"), new Constant(DOUBLE, 3.14))))),
239249
Map.of(
240250
"c_bigint_symbol", BIGINT_COLUMN,
241-
"c_bigint_symbol_2", BIGINT_COLUMN))
251+
"c_real_symbol", REAL_COLUMN,
252+
"c_double_symbol", DOUBLE_COLUMN))
242253
.orElseThrow();
243-
assertThat(converted.expression()).isEqualTo("((\"c_bigint\") = (?)) OR ((\"c_bigint\") = (?))");
254+
assertThat(converted.expression()).isEqualTo("((\"c_bigint\") = (?)) OR ((\"c_real\") = (?)) OR ((\"c_double\") = (?))");
244255
assertThat(converted.parameters()).isEqualTo(List.of(
245256
new QueryParameter(BIGINT, Optional.of(42L)),
246-
new QueryParameter(BIGINT, Optional.of(415L))));
257+
new QueryParameter(REAL, Optional.of((long) floatToIntBits(3.14f))),
258+
new QueryParameter(DOUBLE, Optional.of(3.14))));
247259
}
248260

249261
@Test
@@ -279,8 +291,16 @@ public void testConvertComparison()
279291
Optional<ParameterizedExpression> converted = JDBC_CLIENT.convertPredicate(
280292
SESSION,
281293
translateToConnectorExpression(
282-
new Comparison(operator, new Reference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L))),
283-
Map.of("c_bigint_symbol", BIGINT_COLUMN));
294+
new Logical(
295+
Logical.Operator.OR,
296+
List.of(
297+
new Comparison(operator, new Reference(BIGINT, "c_bigint_symbol"), new Constant(BIGINT, 42L)),
298+
new Comparison(operator, new Reference(REAL, "c_real_symbol"), new Constant(REAL, (long) floatToIntBits(3.14f))),
299+
new Comparison(operator, new Reference(DOUBLE, "c_double_symbol"), new Constant(DOUBLE, 3.14))))),
300+
Map.of(
301+
"c_bigint_symbol", BIGINT_COLUMN,
302+
"c_real_symbol", REAL_COLUMN,
303+
"c_double_symbol", DOUBLE_COLUMN));
284304

285305
switch (operator) {
286306
case EQUAL:
@@ -290,8 +310,11 @@ public void testConvertComparison()
290310
case GREATER_THAN:
291311
case GREATER_THAN_OR_EQUAL:
292312
assertThat(converted).isPresent();
293-
assertThat(converted.get().expression()).isEqualTo(format("(\"c_bigint\") %s (?)", operator.getValue()));
294-
assertThat(converted.get().parameters()).isEqualTo(List.of(new QueryParameter(BIGINT, Optional.of(42L))));
313+
assertThat(converted.get().expression()).isEqualTo(format("((\"c_bigint\") %1$s (?)) OR ((\"c_real\") %1$s (?)) OR ((\"c_double\") %1$s (?))", operator.getValue()));
314+
assertThat(converted.get().parameters()).isEqualTo(List.of(
315+
new QueryParameter(BIGINT, Optional.of(42L)),
316+
new QueryParameter(REAL, Optional.of((long) floatToIntBits(3.14f))),
317+
new QueryParameter(DOUBLE, Optional.of(3.14))));
295318
break;
296319
case IDENTICAL:
297320
assertThat(converted).isPresent();

0 commit comments

Comments
 (0)