From 3999f6208ff376a73f6abb6022dfa35bd146c0b0 Mon Sep 17 00:00:00 2001 From: Mayank Vadariya <48036907+mayankvadariya@users.noreply.github.com> Date: Fri, 13 Sep 2024 18:59:54 -0400 Subject: [PATCH] Convert range predicates to discrete set in Redshift --- .../jdbc/PredicatePushdownController.java | 17 ++++++++ .../trino/plugin/redshift/RedshiftClient.java | 13 +++--- .../redshift/TestRedshiftConnectorTest.java | 42 +++++++++++++++++++ 3 files changed, 66 insertions(+), 6 deletions(-) diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/PredicatePushdownController.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/PredicatePushdownController.java index 3f17b84bd606..3fdd9ff982e3 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/PredicatePushdownController.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/PredicatePushdownController.java @@ -17,9 +17,14 @@ import io.trino.spi.predicate.DiscreteValues; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Ranges; +import io.trino.spi.predicate.ValueSet; import io.trino.spi.type.CharType; +import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; +import java.util.Collection; +import java.util.Optional; + import static com.google.common.base.Preconditions.checkArgument; import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.getDomainCompactionThreshold; import static java.util.Objects.requireNonNull; @@ -62,6 +67,18 @@ public interface PredicatePushdownController return new DomainPushdownResult(simplifiedDomain, domain); }; + static PredicatePushdownController pushdownDiscreteValues(Type type) + { + return (session, domain) -> { + Optional> expandedRange = domain.getValues().tryExpandRanges(getDomainCompactionThreshold(session)); + if (expandedRange.isPresent()) { + Domain convertedDiscreteDomain = Domain.create(ValueSet.copyOf(type, expandedRange.get()), domain.isNullAllowed()); + return new DomainPushdownResult(convertedDiscreteDomain, Domain.all(domain.getType())); + } + return FULL_PUSHDOWN.apply(session, domain); + }; + } + DomainPushdownResult apply(ConnectorSession session, Domain domain); final class DomainPushdownResult diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index c9d8d5a0f0cc..64701f2082c4 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -113,7 +113,7 @@ import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; import static io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware; -import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; +import static io.trino.plugin.jdbc.PredicatePushdownController.pushdownDiscreteValues; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanWriteFunction; @@ -121,14 +121,12 @@ import static io.trino.plugin.jdbc.StandardColumnMappings.decimalColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.doubleColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.doubleWriteFunction; -import static io.trino.plugin.jdbc.StandardColumnMappings.integerColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.integerWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalReadFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.realColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.realWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.shortDecimalWriteFunction; -import static io.trino.plugin.jdbc.StandardColumnMappings.smallintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.smallintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryReadFunction; @@ -618,11 +616,14 @@ public Optional toColumnMapping(ConnectorSession session, Connect // case Types.TINYINT: -- Redshift doesn't support tinyint case Types.SMALLINT: - return Optional.of(smallintColumnMapping()); + // IN clause query in Redshift performs better compared to range queries, hence convert range queries to discrete set where possible. + return Optional.of(ColumnMapping.longMapping(SMALLINT, ResultSet::getShort, smallintWriteFunction(), pushdownDiscreteValues(SMALLINT))); case Types.INTEGER: - return Optional.of(integerColumnMapping()); + // IN clause query in Redshift performs better compared to range queries, hence convert range queries to discrete set where possible. + return Optional.of(ColumnMapping.longMapping(INTEGER, ResultSet::getInt, integerWriteFunction(), pushdownDiscreteValues(INTEGER))); case Types.BIGINT: - return Optional.of(bigintColumnMapping()); + // IN clause query in Redshift performs better compared to range queries, hence convert range queries to discrete set where possible. + return Optional.of(ColumnMapping.longMapping(BIGINT, ResultSet::getLong, bigintWriteFunction(), pushdownDiscreteValues(BIGINT))); case Types.REAL: return Optional.of(realColumnMapping()); diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index ef5ab13ff086..996e2547f083 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -19,9 +19,15 @@ import io.airlift.units.Duration; import io.trino.Session; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcTableHandle; +import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.RemoteDatabaseEvent; import io.trino.plugin.jdbc.RemoteDatabaseEvent.Status; import io.trino.plugin.jdbc.RemoteLogTracingEvent; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.planner.plan.TableScanNode; import io.trino.testing.QueryRunner; import io.trino.testing.TestingConnectorBehavior; import io.trino.testing.sql.SqlExecutor; @@ -30,7 +36,9 @@ import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.Test; +import java.sql.Types; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.OptionalInt; import java.util.Set; @@ -51,6 +59,8 @@ import static io.trino.plugin.redshift.TestingRedshiftServer.TEST_SCHEMA; import static io.trino.plugin.redshift.TestingRedshiftServer.executeInRedshift; import static io.trino.plugin.redshift.TestingRedshiftServer.executeWithRedshift; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.Math.round; import static java.lang.String.format; @@ -267,6 +277,38 @@ public void testRedshiftAddNotNullColumn() } } + @Test + public void testRangeQueryConvertedToInClauseQuery() + { + assertThat(query("SELECT regionkey FROM region WHERE regionkey >= 1 AND regionkey <= 4")) + .isFullyPushedDown(); + assertThat(query("SELECT regionkey FROM region WHERE regionkey >= 1 AND regionkey <= 4")) + .isNotFullyPushedDown(node(TableScanNode.class) + .with(TableScanNode.class, tableScanNode -> { + TupleDomain effectivePredicate = ((JdbcTableHandle) tableScanNode.getTable().connectorHandle()).getConstraint(); + TupleDomain expectedPredicate = + TupleDomain.withColumnDomains( + Map.of( + new JdbcColumnHandle.Builder() + .setColumnName("regionkey") + .setJdbcTypeHandle( + new JdbcTypeHandle( + Types.BIGINT, + Optional.of("int8"), + Optional.of(19), + Optional.of(0), + Optional.empty(), + Optional.empty())) + .setComment(Optional.of("Dynamic Column.")) + .setColumnType(BIGINT) + .setNullable(true) + .build(), + Domain.multipleValues(BIGINT, List.of(1L, 2L, 3L, 4L), false))); + assertThat(effectivePredicate).isEqualTo(expectedPredicate); + return true; + })); + } + @Test @Override public void testDelete()