diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java index 8f5a36924d8b..0840b87c1a51 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java @@ -102,6 +102,7 @@ import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.SmallintType.SMALLINT; import static io.trino.spi.type.TimeType.TIME_MILLIS; import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; @@ -114,6 +115,7 @@ import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static java.lang.Float.intBitsToFloat; import static java.lang.Math.floorDiv; import static java.lang.Math.floorMod; import static java.lang.Math.toIntExact; @@ -688,6 +690,14 @@ private static Optional translateValue(Object trinoNativeValue, Type typ return Optional.of(trinoNativeValue); } + if (type == REAL) { + return Optional.of(intBitsToFloat(toIntExact((long) trinoNativeValue))); + } + + if (type == DOUBLE) { + return Optional.of(trinoNativeValue); + } + if (type instanceof DecimalType decimalType) { if (decimalType.isShort()) { return Optional.of(Decimal128.parse(Decimals.toString((long) trinoNativeValue, decimalType.getScale()))); diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/TypeUtils.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/TypeUtils.java index 291be89e337a..c5d52d4236d3 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/TypeUtils.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/TypeUtils.java @@ -24,7 +24,9 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.SmallintType.SMALLINT; import static io.trino.spi.type.StandardTypes.JSON; import static io.trino.spi.type.TimeType.TIME_MILLIS; @@ -40,6 +42,8 @@ public final class TypeUtils SMALLINT, INTEGER, BIGINT, + REAL, + DOUBLE, DATE, TIME_MILLIS, TIMESTAMP_MILLIS, diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java index 7291a08ca2cd..a6dcbbfdf206 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java @@ -345,6 +345,39 @@ public Object[][] predicatePushdownProvider() }; } + @Test + public void testPredicatePushdownRealType() + { + testPredicatePushdownFloatingPoint("real '1.234'"); + } + + @Test + public void testPredicatePushdownDoubleType() + { + testPredicatePushdownFloatingPoint("double '5.678'"); + } + + private void testPredicatePushdownFloatingPoint(String value) + { + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_floating_point_pushdown", "AS SELECT %s col".formatted(value))) { + assertThat(query("SELECT * FROM " + table.getName() + " WHERE col = " + value)) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM " + table.getName() + " WHERE col <= " + value)) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM " + table.getName() + " WHERE col >= " + value)) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM " + table.getName() + " WHERE col > " + value)) + .returnsEmptyResult() + .isFullyPushedDown(); + assertThat(query("SELECT * FROM " + table.getName() + " WHERE col < " + value)) + .returnsEmptyResult() + .isFullyPushedDown(); + assertThat(query("SELECT * FROM " + table.getName() + " WHERE col != " + value)) + .returnsEmptyResult() + .isNotFullyPushedDown(FilterNode.class); + } + } + @Test public void testPredicatePushdownCharWithPaddedSpace() {