diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ApplyProjectionUtil.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ApplyProjectionUtil.java index 1f83ec4b5c05..e44281e7320e 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ApplyProjectionUtil.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/projection/ApplyProjectionUtil.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Predicate; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @@ -33,31 +34,38 @@ public final class ApplyProjectionUtil private ApplyProjectionUtil() {} public static List extractSupportedProjectedColumns(ConnectorExpression expression) + { + return extractSupportedProjectedColumns(expression, connectorExpression -> true); + } + + public static List extractSupportedProjectedColumns(ConnectorExpression expression, Predicate expressionPredicate) { requireNonNull(expression, "expression is null"); ImmutableList.Builder supportedSubExpressions = ImmutableList.builder(); - fillSupportedProjectedColumns(expression, supportedSubExpressions); + fillSupportedProjectedColumns(expression, supportedSubExpressions, expressionPredicate); return supportedSubExpressions.build(); } - private static void fillSupportedProjectedColumns(ConnectorExpression expression, ImmutableList.Builder supportedSubExpressions) + private static void fillSupportedProjectedColumns(ConnectorExpression expression, ImmutableList.Builder supportedSubExpressions, Predicate expressionPredicate) { - if (isPushDownSupported(expression)) { + if (isPushdownSupported(expression, expressionPredicate)) { supportedSubExpressions.add(expression); return; } // If the whole expression is not supported, look for a partially supported projection for (ConnectorExpression child : expression.getChildren()) { - fillSupportedProjectedColumns(child, supportedSubExpressions); + fillSupportedProjectedColumns(child, supportedSubExpressions, expressionPredicate); } } @VisibleForTesting - static boolean isPushDownSupported(ConnectorExpression expression) + static boolean isPushdownSupported(ConnectorExpression expression, Predicate expressionPredicate) { - return expression instanceof Variable || - (expression instanceof FieldDereference fieldDereference && isPushDownSupported(fieldDereference.getTarget())); + return expressionPredicate.test(expression) + && (expression instanceof Variable || + (expression instanceof FieldDereference fieldDereference + && isPushdownSupported(fieldDereference.getTarget(), expressionPredicate))); } public static ProjectedColumnRepresentation createProjectedColumnRepresentation(ConnectorExpression expression) diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/projection/TestApplyProjectionUtil.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/projection/TestApplyProjectionUtil.java index 1acb21ecabf5..a5a924355210 100644 --- a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/projection/TestApplyProjectionUtil.java +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/projection/TestApplyProjectionUtil.java @@ -18,10 +18,11 @@ import io.trino.spi.expression.Constant; import io.trino.spi.expression.FieldDereference; import io.trino.spi.expression.Variable; +import io.trino.spi.type.RowType; import org.testng.annotations.Test; import static 
io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns; -import static io.trino.plugin.base.projection.ApplyProjectionUtil.isPushDownSupported; +import static io.trino.plugin.base.projection.ApplyProjectionUtil.isPushdownSupported; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; import static io.trino.spi.type.RowType.rowType; @@ -32,6 +33,8 @@ public class TestApplyProjectionUtil { private static final ConnectorExpression ROW_OF_ROW_VARIABLE = new Variable("a", rowType(field("b", rowType(field("c", INTEGER))))); + private static final ConnectorExpression LEAF_DOTTED_ROW_OF_ROW_VARIABLE = new Variable("a", rowType(field("b", rowType(field("c.x", INTEGER))))); + private static final ConnectorExpression MID_DOTTED_ROW_OF_ROW_VARIABLE = new Variable("a", rowType(field("b.x", rowType(field("c", INTEGER))))); private static final ConnectorExpression ONE_LEVEL_DEREFERENCE = new FieldDereference( rowType(field("c", INTEGER)), @@ -43,16 +46,46 @@ public class TestApplyProjectionUtil ONE_LEVEL_DEREFERENCE, 0); + private static final ConnectorExpression LEAF_DOTTED_ONE_LEVEL_DEREFERENCE = new FieldDereference( + rowType(field("c.x", INTEGER)), + LEAF_DOTTED_ROW_OF_ROW_VARIABLE, + 0); + + private static final ConnectorExpression LEAF_DOTTED_TWO_LEVEL_DEREFERENCE = new FieldDereference( + INTEGER, + LEAF_DOTTED_ONE_LEVEL_DEREFERENCE, + 0); + + private static final ConnectorExpression MID_DOTTED_ONE_LEVEL_DEREFERENCE = new FieldDereference( + rowType(field("c.x", INTEGER)), + MID_DOTTED_ROW_OF_ROW_VARIABLE, + 0); + + private static final ConnectorExpression MID_DOTTED_TWO_LEVEL_DEREFERENCE = new FieldDereference( + INTEGER, + MID_DOTTED_ONE_LEVEL_DEREFERENCE, + 0); + private static final ConnectorExpression INT_VARIABLE = new Variable("a", INTEGER); private static final ConnectorExpression CONSTANT = new Constant(5, INTEGER); @Test public void testIsProjectionSupported() { - assertTrue(isPushDownSupported(ONE_LEVEL_DEREFERENCE)); - assertTrue(isPushDownSupported(TWO_LEVEL_DEREFERENCE)); - assertTrue(isPushDownSupported(INT_VARIABLE)); - assertFalse(isPushDownSupported(CONSTANT)); + assertTrue(isPushdownSupported(ONE_LEVEL_DEREFERENCE, connectorExpression -> true)); + assertTrue(isPushdownSupported(TWO_LEVEL_DEREFERENCE, connectorExpression -> true)); + assertTrue(isPushdownSupported(INT_VARIABLE, connectorExpression -> true)); + assertFalse(isPushdownSupported(CONSTANT, connectorExpression -> true)); + + assertFalse(isPushdownSupported(ONE_LEVEL_DEREFERENCE, connectorExpression -> false)); + assertFalse(isPushdownSupported(TWO_LEVEL_DEREFERENCE, connectorExpression -> false)); + assertFalse(isPushdownSupported(INT_VARIABLE, connectorExpression -> false)); + assertFalse(isPushdownSupported(CONSTANT, connectorExpression -> false)); + + assertTrue(isPushdownSupported(LEAF_DOTTED_ONE_LEVEL_DEREFERENCE, this::isSupportedForPushDown)); + assertFalse(isPushdownSupported(LEAF_DOTTED_TWO_LEVEL_DEREFERENCE, this::isSupportedForPushDown)); + assertFalse(isPushdownSupported(MID_DOTTED_ONE_LEVEL_DEREFERENCE, this::isSupportedForPushDown)); + assertFalse(isPushdownSupported(MID_DOTTED_TWO_LEVEL_DEREFERENCE, this::isSupportedForPushDown)); } @Test @@ -62,5 +95,32 @@ public void testExtractSupportedProjectionColumns() assertEquals(extractSupportedProjectedColumns(TWO_LEVEL_DEREFERENCE), ImmutableList.of(TWO_LEVEL_DEREFERENCE)); assertEquals(extractSupportedProjectedColumns(INT_VARIABLE), ImmutableList.of(INT_VARIABLE)); 
assertEquals(extractSupportedProjectedColumns(CONSTANT), ImmutableList.of()); + + assertEquals(extractSupportedProjectedColumns(ONE_LEVEL_DEREFERENCE, connectorExpression -> false), ImmutableList.of()); + assertEquals(extractSupportedProjectedColumns(TWO_LEVEL_DEREFERENCE, connectorExpression -> false), ImmutableList.of()); + assertEquals(extractSupportedProjectedColumns(INT_VARIABLE, connectorExpression -> false), ImmutableList.of()); + assertEquals(extractSupportedProjectedColumns(CONSTANT, connectorExpression -> false), ImmutableList.of()); + + // Partially supported projections + assertEquals(extractSupportedProjectedColumns(LEAF_DOTTED_ONE_LEVEL_DEREFERENCE, this::isSupportedForPushDown), ImmutableList.of(LEAF_DOTTED_ONE_LEVEL_DEREFERENCE)); + assertEquals(extractSupportedProjectedColumns(LEAF_DOTTED_TWO_LEVEL_DEREFERENCE, this::isSupportedForPushDown), ImmutableList.of(LEAF_DOTTED_ONE_LEVEL_DEREFERENCE)); + assertEquals(extractSupportedProjectedColumns(MID_DOTTED_ONE_LEVEL_DEREFERENCE, this::isSupportedForPushDown), ImmutableList.of(MID_DOTTED_ROW_OF_ROW_VARIABLE)); + assertEquals(extractSupportedProjectedColumns(MID_DOTTED_TWO_LEVEL_DEREFERENCE, this::isSupportedForPushDown), ImmutableList.of(MID_DOTTED_ROW_OF_ROW_VARIABLE)); + } + + /** + * This method is used to simulate the behavior when the field passed in the connectorExpression might not be supported for pushdown. + */ + private boolean isSupportedForPushDown(ConnectorExpression connectorExpression) + { + if (connectorExpression instanceof FieldDereference fieldDereference) { + RowType rowType = (RowType) fieldDereference.getTarget().getType(); + RowType.Field field = rowType.getFields().get(fieldDereference.getField()); + String fieldName = field.getName().get(); + if (fieldName.contains(".") || fieldName.contains("$")) { + return false; + } + } + return true; } } diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoClientConfig.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoClientConfig.java index b607c14e64e5..9a76138cdbe1 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoClientConfig.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoClientConfig.java @@ -14,6 +14,7 @@ package io.trino.plugin.mongodb; import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; import io.airlift.configuration.ConfigSecuritySensitive; import io.airlift.configuration.DefunctConfig; import io.airlift.configuration.LegacyConfig; @@ -44,6 +45,7 @@ public class MongoClientConfig private WriteConcernType writeConcern = WriteConcernType.ACKNOWLEDGED; private String requiredReplicaSetName; private String implicitRowFieldPrefix = "_pos"; + private boolean projectionPushDownEnabled = true; @NotNull public String getSchemaCollection() @@ -237,4 +239,17 @@ public MongoClientConfig setMaxConnectionIdleTime(int maxConnectionIdleTime) this.maxConnectionIdleTime = maxConnectionIdleTime; return this; } + + public boolean isProjectionPushdownEnabled() + { + return projectionPushDownEnabled; + } + + @Config("mongodb.projection-pushdown-enabled") + @ConfigDescription("Read only required fields from a row type") + public MongoClientConfig setProjectionPushdownEnabled(boolean projectionPushDownEnabled) + { + this.projectionPushDownEnabled = projectionPushDownEnabled; + return this; + } } diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoClientModule.java
b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoClientModule.java index 930cf06707e4..60441e7ecca7 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoClientModule.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoClientModule.java @@ -23,6 +23,7 @@ import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.plugin.base.session.SessionPropertiesProvider; import io.trino.plugin.mongodb.ptf.Query; import io.trino.spi.function.table.ConnectorTableFunction; import io.trino.spi.type.TypeManager; @@ -44,6 +45,7 @@ public void setup(Binder binder) binder.bind(MongoSplitManager.class).in(Scopes.SINGLETON); binder.bind(MongoPageSourceProvider.class).in(Scopes.SINGLETON); binder.bind(MongoPageSinkProvider.class).in(Scopes.SINGLETON); + newSetBinder(binder, SessionPropertiesProvider.class).addBinding().to(MongoSessionProperties.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(MongoClientConfig.class); newSetBinder(binder, MongoClientSettingConfigurator.class); diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoColumnHandle.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoColumnHandle.java index 71b62bfcc84b..2692d2d48cd9 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoColumnHandle.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoColumnHandle.java @@ -14,12 +14,16 @@ package io.trino.plugin.mongodb; import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.type.Type; import org.bson.Document; +import java.util.List; import java.util.Objects; import java.util.Optional; @@ -28,28 +32,41 @@ public class MongoColumnHandle implements ColumnHandle { - private final String name; + private final String baseName; + private final List dereferenceNames; private final Type type; private final boolean hidden; + // Represents whether the field is inside a DBRef type + private final boolean dbRefField; private final Optional comment; @JsonCreator public MongoColumnHandle( - @JsonProperty("name") String name, + @JsonProperty("baseName") String baseName, + @JsonProperty("dereferenceNames") List dereferenceNames, @JsonProperty("columnType") Type type, @JsonProperty("hidden") boolean hidden, + @JsonProperty("dbRefField") boolean dbRefField, @JsonProperty("comment") Optional comment) { - this.name = requireNonNull(name, "name is null"); + this.baseName = requireNonNull(baseName, "baseName is null"); + this.dereferenceNames = ImmutableList.copyOf(requireNonNull(dereferenceNames, "dereferenceNames is null")); this.type = requireNonNull(type, "type is null"); this.hidden = hidden; + this.dbRefField = dbRefField; this.comment = requireNonNull(comment, "comment is null"); } @JsonProperty - public String getName() + public String getBaseName() { - return name; + return baseName; + } + + @JsonProperty + public List getDereferenceNames() + { + return dereferenceNames; } @JsonProperty("columnType") @@ -64,6 +81,15 @@ public boolean isHidden() return hidden; } + /** + * This method may return a wrong value when a row type uses the same field names and types as
dbref. + */ + @JsonProperty + public boolean isDbRefField() + { + return dbRefField; + } + @JsonProperty public Optional getComment() { @@ -73,25 +99,42 @@ public Optional getComment() public ColumnMetadata toColumnMetadata() { return ColumnMetadata.builder() - .setName(name) + .setName(getQualifiedName()) .setType(type) .setHidden(hidden) .setComment(comment) .build(); } + @JsonIgnore + public String getQualifiedName() + { + return Joiner.on('.') + .join(ImmutableList.builder() + .add(baseName) + .addAll(dereferenceNames) + .build()); + } + + @JsonIgnore + public boolean isBaseColumn() + { + return dereferenceNames.isEmpty(); + } + public Document getDocument() { - return new Document().append("name", name) + return new Document().append("name", getQualifiedName()) .append("type", type.getTypeSignature().toString()) .append("hidden", hidden) + .append("dbRefField", dbRefField) .append("comment", comment.orElse(null)); } @Override public int hashCode() { - return Objects.hash(name, type, hidden, comment); + return Objects.hash(baseName, dereferenceNames, type, hidden, dbRefField, comment); } @Override @@ -104,15 +147,17 @@ public boolean equals(Object obj) return false; } MongoColumnHandle other = (MongoColumnHandle) obj; - return Objects.equals(name, other.name) && + return Objects.equals(baseName, other.baseName) && + Objects.equals(dereferenceNames, other.dereferenceNames) && Objects.equals(type, other.type) && Objects.equals(hidden, other.hidden) && + Objects.equals(dbRefField, other.dbRefField) && Objects.equals(comment, other.comment); } @Override public String toString() { - return name + ":" + type; + return getQualifiedName() + ":" + type; } } diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoConnector.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoConnector.java index a6449ca1a0be..0b9caff7ff5d 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoConnector.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoConnector.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableSet; import com.google.inject.Inject; +import io.trino.plugin.base.session.SessionPropertiesProvider; import io.trino.spi.connector.Connector; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorPageSinkProvider; @@ -23,13 +24,16 @@ import io.trino.spi.connector.ConnectorSplitManager; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.function.table.ConnectorTableFunction; +import io.trino.spi.session.PropertyMetadata; import io.trino.spi.transaction.IsolationLevel; +import java.util.List; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.spi.transaction.IsolationLevel.READ_UNCOMMITTED; import static io.trino.spi.transaction.IsolationLevel.checkConnectorSupports; import static java.util.Objects.requireNonNull; @@ -42,6 +46,7 @@ public class MongoConnector private final MongoPageSourceProvider pageSourceProvider; private final MongoPageSinkProvider pageSinkProvider; private final Set connectorTableFunctions; + private final List> sessionProperties; private final ConcurrentMap transactions = new ConcurrentHashMap<>(); @@ -51,13 +56,17 @@ public MongoConnector( MongoSplitManager splitManager, MongoPageSourceProvider 
pageSourceProvider, MongoPageSinkProvider pageSinkProvider, - Set connectorTableFunctions) + Set connectorTableFunctions, + Set sessionPropertiesProviders) { this.mongoSession = mongoSession; this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null"); this.pageSinkProvider = requireNonNull(pageSinkProvider, "pageSinkProvider is null"); this.connectorTableFunctions = ImmutableSet.copyOf(requireNonNull(connectorTableFunctions, "connectorTableFunctions is null")); + this.sessionProperties = sessionPropertiesProviders.stream() + .flatMap(sessionPropertiesProvider -> sessionPropertiesProvider.getSessionProperties().stream()) + .collect(toImmutableList()); } @Override @@ -115,6 +124,12 @@ public Set getTableFunctions() return connectorTableFunctions; } + @Override + public List> getSessionProperties() + { + return sessionProperties; + } + @Override public void shutdown() { diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java index 94bd3b4ba417..d8c5d8b031ae 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java @@ -15,14 +15,17 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import com.google.common.io.Closer; import com.mongodb.client.MongoCollection; import io.airlift.log.Logger; import io.airlift.slice.Slice; +import io.trino.plugin.base.projection.ApplyProjectionUtil; import io.trino.plugin.mongodb.MongoIndex.MongodbIndexKey; import io.trino.plugin.mongodb.ptf.Query.QueryFunctionHandle; import io.trino.spi.TrinoException; +import io.trino.spi.connector.Assignment; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorInsertTableHandle; @@ -39,12 +42,16 @@ import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.LocalProperty; import io.trino.spi.connector.NotFoundException; +import io.trino.spi.connector.ProjectionApplicationResult; import io.trino.spi.connector.RetryMode; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.SortingProperty; import io.trino.spi.connector.TableFunctionApplicationResult; import io.trino.spi.connector.TableNotFoundException; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.FieldDereference; +import io.trino.spi.expression.Variable; import io.trino.spi.function.table.ConnectorTableFunctionHandle; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; @@ -74,6 +81,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.MoreCollectors.onlyElement; import static com.mongodb.client.model.Aggregates.lookup; @@ -83,6 +91,13 @@ import static com.mongodb.client.model.Filters.ne; import static com.mongodb.client.model.Projections.exclude; 
import static io.trino.plugin.base.TemporaryTables.generateTemporaryTableName; +import static io.trino.plugin.base.projection.ApplyProjectionUtil.ProjectedColumnRepresentation; +import static io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns; +import static io.trino.plugin.base.projection.ApplyProjectionUtil.replaceWithNewVariables; +import static io.trino.plugin.mongodb.MongoSession.COLLECTION_NAME; +import static io.trino.plugin.mongodb.MongoSession.DATABASE_NAME; +import static io.trino.plugin.mongodb.MongoSession.ID; +import static io.trino.plugin.mongodb.MongoSessionProperties.isProjectionPushdownEnabled; import static io.trino.plugin.mongodb.TypeUtils.isPushdownSupportedType; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; @@ -92,11 +107,13 @@ import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.SmallintType.SMALLINT; import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.VarcharType.VARCHAR; import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; import static java.util.stream.Collectors.toList; public class MongoMetadata @@ -178,7 +195,7 @@ public Map getColumnHandles(ConnectorSession session, Conn ImmutableMap.Builder columnHandles = ImmutableMap.builder(); for (MongoColumnHandle columnHandle : columns) { - columnHandles.put(columnHandle.getName().toLowerCase(ENGLISH), columnHandle); + columnHandles.put(columnHandle.getBaseName().toLowerCase(ENGLISH), columnHandle); } return columnHandles.buildOrThrow(); } @@ -240,7 +257,7 @@ public void setColumnComment(ConnectorSession session, ConnectorTableHandle tabl { MongoTableHandle table = (MongoTableHandle) tableHandle; MongoColumnHandle column = (MongoColumnHandle) columnHandle; - mongoSession.setColumnComment(table, column.getName(), comment); + mongoSession.setColumnComment(table, column.getBaseName(), comment); } @Override @@ -262,13 +279,13 @@ public void addColumn(ConnectorSession session, ConnectorTableHandle tableHandle @Override public void renameColumn(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle source, String target) { - mongoSession.renameColumn(((MongoTableHandle) tableHandle), ((MongoColumnHandle) source).getName(), target); + mongoSession.renameColumn(((MongoTableHandle) tableHandle), ((MongoColumnHandle) source).getBaseName(), target); } @Override public void dropColumn(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle column) { - mongoSession.dropColumn(((MongoTableHandle) tableHandle), ((MongoColumnHandle) column).getName()); + mongoSession.dropColumn(((MongoTableHandle) tableHandle), ((MongoColumnHandle) column).getBaseName()); } @Override @@ -279,7 +296,7 @@ public void setColumnType(ConnectorSession session, ConnectorTableHandle tableHa if (!canChangeColumnType(column.getType(), type)) { throw new TrinoException(NOT_SUPPORTED, "Cannot change type from %s to %s".formatted(column.getType(), type)); } - mongoSession.setColumnType(table, column.getName(), type); + mongoSession.setColumnType(table, column.getBaseName(), type); } private static boolean canChangeColumnType(Type sourceType, Type newType) @@ -375,7 +392,7 @@ public ConnectorOutputTableHandle 
beginCreateTable(ConnectorSession session, Con Optional.empty()); } - MongoColumnHandle pageSinkIdColumn = buildPageSinkIdColumn(columns.stream().map(MongoColumnHandle::getName).collect(toImmutableSet())); + MongoColumnHandle pageSinkIdColumn = buildPageSinkIdColumn(columns.stream().map(MongoColumnHandle::getBaseName).collect(toImmutableSet())); List allTemporaryTableColumns = ImmutableList.builderWithExpectedSize(columns.size() + 1) .addAll(columns) .add(pageSinkIdColumn) @@ -388,7 +405,7 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con remoteTableName, handleColumns, Optional.of(temporaryTable.getCollectionName()), - Optional.of(pageSinkIdColumn.getName())); + Optional.of(pageSinkIdColumn.getBaseName())); } @Override @@ -410,7 +427,7 @@ public ConnectorInsertTableHandle beginInsert(ConnectorSession session, Connecto List columns = table.getColumns(); List handleColumns = columns.stream() .filter(column -> !column.isHidden()) - .peek(column -> validateColumnNameForInsert(column.getName())) + .peek(column -> validateColumnNameForInsert(column.getBaseName())) .collect(toImmutableList()); if (retryMode == RetryMode.NO_RETRIES) { @@ -420,7 +437,7 @@ public ConnectorInsertTableHandle beginInsert(ConnectorSession session, Connecto Optional.empty(), Optional.empty()); } - MongoColumnHandle pageSinkIdColumn = buildPageSinkIdColumn(columns.stream().map(MongoColumnHandle::getName).collect(toImmutableSet())); + MongoColumnHandle pageSinkIdColumn = buildPageSinkIdColumn(columns.stream().map(MongoColumnHandle::getBaseName).collect(toImmutableSet())); List allColumns = ImmutableList.builderWithExpectedSize(columns.size() + 1) .addAll(columns) .add(pageSinkIdColumn) @@ -435,7 +452,7 @@ public ConnectorInsertTableHandle beginInsert(ConnectorSession session, Connecto handle.getRemoteTableName(), handleColumns, Optional.of(temporaryTable.getCollectionName()), - Optional.of(pageSinkIdColumn.getName())); + Optional.of(pageSinkIdColumn.getBaseName())); } @Override @@ -462,7 +479,7 @@ private void finishInsert( try { // Create the temporary page sink ID table RemoteTableName pageSinkIdsTable = new RemoteTableName(temporaryTable.getDatabaseName(), generateTemporaryTableName(session)); - MongoColumnHandle pageSinkIdColumn = new MongoColumnHandle(pageSinkIdColumnName, TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, Optional.empty()); + MongoColumnHandle pageSinkIdColumn = new MongoColumnHandle(pageSinkIdColumnName, ImmutableList.of(), TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, false, Optional.empty()); mongoSession.createTable(pageSinkIdsTable, ImmutableList.of(pageSinkIdColumn), Optional.empty()); closer.register(() -> mongoSession.dropTable(pageSinkIdsTable)); @@ -494,7 +511,7 @@ private void finishInsert( @Override public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) { - return new MongoColumnHandle("$merge_row_id", BIGINT, true, Optional.empty()); + return new MongoColumnHandle("$merge_row_id", ImmutableList.of(), BIGINT, true, false, Optional.empty()); } @Override @@ -558,7 +575,13 @@ public Optional> applyLimit(Connect } return Optional.of(new LimitApplicationResult<>( - new MongoTableHandle(handle.getSchemaTableName(), handle.getRemoteTableName(), handle.getFilter(), handle.getConstraint(), OptionalInt.of(toIntExact(limit))), + new MongoTableHandle( + handle.getSchemaTableName(), + handle.getRemoteTableName(), + handle.getFilter(), + handle.getConstraint(), + handle.getProjectedColumns(), + OptionalInt.of(toIntExact(limit))), true, 
false)); } @@ -605,11 +628,162 @@ public Optional> applyFilter(C handle.getRemoteTableName(), handle.getFilter(), newDomain, + handle.getProjectedColumns(), handle.getLimit()); return Optional.of(new ConstraintApplicationResult<>(handle, remainingFilter, false)); } + @Override + public Optional> applyProjection( + ConnectorSession session, + ConnectorTableHandle handle, + List projections, + Map assignments) + { + if (!isProjectionPushdownEnabled(session)) { + return Optional.empty(); + } + // Create projected column representations for supported sub expressions. Simple column references and chain of + // dereferences on a variable are supported right now. + Set projectedExpressions = projections.stream() + .flatMap(expression -> extractSupportedProjectedColumns(expression, MongoMetadata::isSupportedForPushdown).stream()) + .collect(toImmutableSet()); + + Map columnProjections = projectedExpressions.stream() + .collect(toImmutableMap(identity(), ApplyProjectionUtil::createProjectedColumnRepresentation)); + + MongoTableHandle mongoTableHandle = (MongoTableHandle) handle; + + // all references are simple variables + if (columnProjections.values().stream().allMatch(ProjectedColumnRepresentation::isVariable)) { + Set projectedColumns = assignments.values().stream() + .map(MongoColumnHandle.class::cast) + .collect(toImmutableSet()); + if (mongoTableHandle.getProjectedColumns().equals(projectedColumns)) { + return Optional.empty(); + } + List assignmentsList = assignments.entrySet().stream() + .map(assignment -> new Assignment( + assignment.getKey(), + assignment.getValue(), + ((MongoColumnHandle) assignment.getValue()).getType())) + .collect(toImmutableList()); + + return Optional.of(new ProjectionApplicationResult<>( + mongoTableHandle.withProjectedColumns(projectedColumns), + projections, + assignmentsList, + false)); + } + + Map newAssignments = new HashMap<>(); + ImmutableMap.Builder newVariablesBuilder = ImmutableMap.builder(); + ImmutableSet.Builder projectedColumnsBuilder = ImmutableSet.builder(); + + for (Map.Entry entry : columnProjections.entrySet()) { + ConnectorExpression expression = entry.getKey(); + ProjectedColumnRepresentation projectedColumn = entry.getValue(); + + MongoColumnHandle baseColumnHandle = (MongoColumnHandle) assignments.get(projectedColumn.getVariable().getName()); + MongoColumnHandle projectedColumnHandle = projectColumn(baseColumnHandle, projectedColumn.getDereferenceIndices(), expression.getType()); + String projectedColumnName = projectedColumnHandle.getQualifiedName(); + + Variable projectedColumnVariable = new Variable(projectedColumnName, expression.getType()); + Assignment newAssignment = new Assignment(projectedColumnName, projectedColumnHandle, expression.getType()); + newAssignments.putIfAbsent(projectedColumnName, newAssignment); + + newVariablesBuilder.put(expression, projectedColumnVariable); + projectedColumnsBuilder.add(projectedColumnHandle); + } + + // Modify projections to refer to new variables + Map newVariables = newVariablesBuilder.buildOrThrow(); + List newProjections = projections.stream() + .map(expression -> replaceWithNewVariables(expression, newVariables)) + .collect(toImmutableList()); + + List outputAssignments = newAssignments.values().stream().collect(toImmutableList()); + return Optional.of(new ProjectionApplicationResult<>( + mongoTableHandle.withProjectedColumns(projectedColumnsBuilder.build()), + newProjections, + outputAssignments, + false)); + } + + private static boolean isSupportedForPushdown(ConnectorExpression 
connectorExpression) + { + if (connectorExpression instanceof Variable) { + return true; + } + if (connectorExpression instanceof FieldDereference fieldDereference) { + RowType rowType = (RowType) fieldDereference.getTarget().getType(); + if (isDBRefField(rowType)) { + return false; + } + Field field = rowType.getFields().get(fieldDereference.getField()); + if (field.getName().isEmpty()) { + return false; + } + String fieldName = field.getName().get(); + if (fieldName.contains(".") || fieldName.contains("$")) { + return false; + } + return true; + } + return false; + } + + private static MongoColumnHandle projectColumn(MongoColumnHandle baseColumn, List indices, Type projectedColumnType) + { + if (indices.isEmpty()) { + return baseColumn; + } + ImmutableList.Builder dereferenceNamesBuilder = ImmutableList.builder(); + dereferenceNamesBuilder.addAll(baseColumn.getDereferenceNames()); + + Type type = baseColumn.getType(); + RowType parentType = null; + for (int index : indices) { + checkArgument(type instanceof RowType, "type should be Row type"); + RowType rowType = (RowType) type; + Field field = rowType.getFields().get(index); + dereferenceNamesBuilder.add(field.getName() + .orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "ROW type does not have field names declared: " + rowType))); + parentType = rowType; + type = field.getType(); + } + return new MongoColumnHandle( + baseColumn.getBaseName(), + dereferenceNamesBuilder.build(), + projectedColumnType, + baseColumn.isHidden(), + isDBRefField(parentType), + baseColumn.getComment()); + } + + /** + * This method may return a wrong flag when a row type uses the same field names and types as dbref. + */ + private static boolean isDBRefField(Type type) + { + if (!(type instanceof RowType rowType)) { + return false; + } + requireNonNull(type, "type is null"); + // When projected field is inside DBRef type field + List fields = rowType.getFields(); + if (fields.size() != 3) { + return false; + } + return fields.get(0).getName().orElseThrow().equals(DATABASE_NAME) + && fields.get(0).getType().equals(VARCHAR) + && fields.get(1).getName().orElseThrow().equals(COLLECTION_NAME) + && fields.get(1).getType().equals(VARCHAR) + && fields.get(2).getName().orElseThrow().equals(ID); + // Id type can be of any type + } + @Override public Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle) { @@ -658,7 +832,7 @@ private ConnectorTableMetadata getTableMetadata(SchemaTableName tableName) private static List buildColumnHandles(ConnectorTableMetadata tableMetadata) { return tableMetadata.getColumns().stream() - .map(m -> new MongoColumnHandle(m.getName(), m.getType(), m.isHidden(), Optional.ofNullable(m.getComment()))) + .map(m -> new MongoColumnHandle(m.getName(), ImmutableList.of(), m.getType(), m.isHidden(), false, Optional.ofNullable(m.getComment()))) .collect(toList()); } @@ -680,6 +854,6 @@ private static MongoColumnHandle buildPageSinkIdColumn(Set otherColumnNa columnName = baseColumnName + "_" + suffix; suffix++; } - return new MongoColumnHandle(columnName, TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, Optional.empty()); + return new MongoColumnHandle(columnName, ImmutableList.of(), TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, false, Optional.empty()); } } diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java index a0a910a55cff..915abcb3f34d 100644 ---
a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java @@ -116,7 +116,7 @@ public CompletableFuture appendPage(Page page) for (int channel = 0; channel < page.getChannelCount(); channel++) { MongoColumnHandle column = columns.get(channel); - doc.append(column.getName(), getObjectValue(columns.get(channel).getType(), page.getBlock(channel), position)); + doc.append(column.getBaseName(), getObjectValue(columns.get(channel).getType(), page.getBlock(channel), position)); } batch.add(doc); } diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java index 73252eee1f96..f0b11b19fcad 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.mongodb; +import com.google.common.collect.ImmutableList; import com.google.common.primitives.Shorts; import com.google.common.primitives.SignedBytes; import com.mongodb.DBRef; @@ -50,6 +51,7 @@ import java.util.Map; import java.util.concurrent.TimeUnit; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static io.airlift.slice.Slices.utf8Slice; @@ -81,6 +83,7 @@ import static java.lang.Float.floatToIntBits; import static java.lang.Math.multiplyExact; import static java.lang.String.join; +import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; public class MongoPageSource @@ -90,7 +93,7 @@ public class MongoPageSource private static final int ROWS_PER_REQUEST = 1024; private final MongoCursor cursor; - private final List columnNames; + private final List columns; private final List columnTypes; private Document currentDoc; private boolean finished; @@ -102,7 +105,7 @@ public MongoPageSource( MongoTableHandle tableHandle, List columns) { - this.columnNames = columns.stream().map(MongoColumnHandle::getName).collect(toList()); + this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); this.columnTypes = columns.stream().map(MongoColumnHandle::getType).collect(toList()); this.cursor = mongoSession.execute(tableHandle, columns); currentDoc = null; @@ -148,7 +151,8 @@ public Page getNextPage() pageBuilder.declarePosition(); for (int column = 0; column < columnTypes.size(); column++) { BlockBuilder output = pageBuilder.getBlockBuilder(column); - appendTo(columnTypes.get(column), currentDoc.get(columnNames.get(column)), output); + MongoColumnHandle columnHandle = columns.get(column); + appendTo(columnTypes.get(column), getColumnValue(currentDoc, columnHandle), output); } } @@ -360,6 +364,50 @@ else if (type instanceof RowType rowType) { output.appendNull(); } + private static Object getColumnValue(Document document, MongoColumnHandle mongoColumnHandle) + { + Object value = document.get(mongoColumnHandle.getBaseName()); + if (mongoColumnHandle.isBaseColumn()) { + return value; + } + if (value instanceof DBRef dbRefValue) { + return getDbRefValue(dbRefValue, mongoColumnHandle); + } + Document documentValue = (Document) value; + for (String dereferenceName : mongoColumnHandle.getDereferenceNames()) { + // When parent field itself is null + if (documentValue == null) { + return null; + } 
+ value = documentValue.get(dereferenceName); + if (value instanceof Document nestedDocument) { + documentValue = nestedDocument; + } + else if (value instanceof DBRef dbRefValue) { + // Assuming DBRefField is the leaf field + return getDbRefValue(dbRefValue, mongoColumnHandle); + } + } + return value; + } + + private static Object getDbRefValue(DBRef dbRefValue, MongoColumnHandle columnHandle) + { + if (columnHandle.getType() instanceof RowType) { + return dbRefValue; + } + checkArgument(columnHandle.isDbRefField(), "columnHandle is not a dbRef field: " + columnHandle); + List dereferenceNames = columnHandle.getDereferenceNames(); + checkState(!dereferenceNames.isEmpty(), "dereferenceNames is empty"); + String leafColumnName = dereferenceNames.get(dereferenceNames.size() - 1); + return switch (leafColumnName) { + case DATABASE_NAME -> dbRefValue.getDatabaseName(); + case COLLECTION_NAME -> dbRefValue.getCollectionName(); + case ID -> dbRefValue.getId(); + default -> throw new IllegalStateException("Unsupported DBRef column name: " + leafColumnName); + }; + } + @Override public void close() { diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java index 2b72365346d0..846438ae0543 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Ordering; import com.google.common.collect.Streams; import com.google.common.primitives.Primitives; import com.google.common.primitives.Shorts; @@ -68,6 +69,7 @@ import java.time.LocalTime; import java.util.ArrayList; import java.util.Arrays; +import java.util.Comparator; import java.util.Date; import java.util.List; import java.util.Map; @@ -151,6 +153,9 @@ public class MongoSession .put("authorizedCollections", true) .buildOrThrow(); + private static final Ordering COLUMN_HANDLE_ORDERING = Ordering + .from(Comparator.comparingInt(columnHandle -> columnHandle.getDereferenceNames().size())); + private final TypeManager typeManager; private final MongoClient client; @@ -440,7 +445,7 @@ private MongoColumnHandle buildColumnHandle(Document columnMeta) Type type = typeManager.fromSqlType(typeString); - return new MongoColumnHandle(name, type, hidden, Optional.ofNullable(comment)); + return new MongoColumnHandle(name, ImmutableList.of(), type, hidden, false, Optional.ofNullable(comment)); } private List getColumnMetadata(Document doc) @@ -482,9 +487,14 @@ public long deleteDocuments(RemoteTableName remoteTableName, TupleDomain execute(MongoTableHandle tableHandle, List columns) { + Set projectedColumns = tableHandle.getProjectedColumns(); + checkArgument(projectedColumns.isEmpty() || projectedColumns.containsAll(columns), "projectedColumns must be empty or equal to columns"); + Document output = new Document(); - for (MongoColumnHandle column : columns) { - output.append(column.getName(), 1); + // Starting in MongoDB 4.4, it is illegal to project an embedded document with any of the embedded document's fields + // (https://www.mongodb.com/docs/manual/reference/limits/#mongodb-limit-Projection-Restrictions). So, Project only sufficient columns. 
+ for (MongoColumnHandle column : projectSufficientColumns(columns)) { + output.append(column.getQualifiedName(), 1); } MongoCollection collection = getCollection(tableHandle.getRemoteTableName()); Document filter = buildFilter(tableHandle); @@ -499,6 +509,37 @@ public MongoCursor execute(MongoTableHandle tableHandle, List projectSufficientColumns(List columnHandles) + { + List sortedColumnHandles = COLUMN_HANDLE_ORDERING.sortedCopy(columnHandles); + List sufficientColumns = new ArrayList<>(); + for (MongoColumnHandle column : sortedColumnHandles) { + if (!parentColumnExists(sufficientColumns, column)) { + sufficientColumns.add(column); + } + } + return sufficientColumns; + } + + private static boolean parentColumnExists(List existingColumns, MongoColumnHandle column) + { + for (MongoColumnHandle existingColumn : existingColumns) { + List existingColumnDereferenceNames = existingColumn.getDereferenceNames(); + verify( + column.getDereferenceNames().size() >= existingColumnDereferenceNames.size(), + "Selected column's dereference size must be greater than or equal to the existing column's dereference size"); + if (existingColumn.getBaseName().equals(column.getBaseName()) + && column.getDereferenceNames().subList(0, existingColumnDereferenceNames.size()).equals(existingColumnDereferenceNames)) { + return true; + } + } + return false; + } + static Document buildFilter(MongoTableHandle table) { // Use $and operator because Document.putAll method overwrites existing entries where the key already exists @@ -525,7 +566,7 @@ static Document buildQuery(TupleDomain tupleDomain) private static Optional buildPredicate(MongoColumnHandle column, Domain domain) { - String name = column.getName(); + String name = column.getQualifiedName(); Type type = column.getType(); if (domain.getValues().isNone() && domain.isNullAllowed()) { return Optional.of(documentOf(name, isNullPredicate())); @@ -740,8 +781,8 @@ private void createTableMetadata(RemoteTableName remoteSchemaTableName, List fields = new ArrayList<>(); - if (!columns.stream().anyMatch(c -> c.getName().equals("_id"))) { - fields.add(new MongoColumnHandle("_id", OBJECT_ID, true, Optional.empty()).getDocument()); + if (!columns.stream().anyMatch(c -> c.getBaseName().equals("_id"))) { + fields.add(new MongoColumnHandle("_id", ImmutableList.of(), OBJECT_ID, true, false, Optional.empty()).getDocument()); } fields.addAll(columns.stream() diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSessionProperties.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSessionProperties.java new file mode 100644 index 000000000000..1a6036da1ecc --- /dev/null +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSessionProperties.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.mongodb; + +import com.google.common.collect.ImmutableList; +import com.google.inject.Inject; +import io.trino.plugin.base.session.SessionPropertiesProvider; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.session.PropertyMetadata; + +import java.util.List; + +import static io.trino.spi.session.PropertyMetadata.booleanProperty; + +public final class MongoSessionProperties + implements SessionPropertiesProvider +{ + private static final String PROJECTION_PUSHDOWN_ENABLED = "projection_pushdown_enabled"; + + private final List> sessionProperties; + + @Inject + public MongoSessionProperties(MongoClientConfig mongoConfig) + { + sessionProperties = ImmutableList.>builder() + .add(booleanProperty( + PROJECTION_PUSHDOWN_ENABLED, + "Read only required fields from a row type", + mongoConfig.isProjectionPushdownEnabled(), + false)) + .build(); + } + + @Override + public List> getSessionProperties() + { + return sessionProperties; + } + + public static boolean isProjectionPushdownEnabled(ConnectorSession session) + { + return session.getProperty(PROJECTION_PUSHDOWN_ENABLED, Boolean.class); + } +} diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoTableHandle.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoTableHandle.java index 63fdbe06629c..e52b35337748 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoTableHandle.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoTableHandle.java @@ -15,6 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableSet; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.SchemaTableName; @@ -23,6 +24,7 @@ import java.util.Objects; import java.util.Optional; import java.util.OptionalInt; +import java.util.Set; import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; @@ -34,11 +36,12 @@ public class MongoTableHandle private final RemoteTableName remoteTableName; private final Optional filter; private final TupleDomain constraint; + private final Set projectedColumns; private final OptionalInt limit; public MongoTableHandle(SchemaTableName schemaTableName, RemoteTableName remoteTableName, Optional filter) { - this(schemaTableName, remoteTableName, filter, TupleDomain.all(), OptionalInt.empty()); + this(schemaTableName, remoteTableName, filter, TupleDomain.all(), ImmutableSet.of(), OptionalInt.empty()); } @JsonCreator @@ -47,12 +50,14 @@ public MongoTableHandle( @JsonProperty("remoteTableName") RemoteTableName remoteTableName, @JsonProperty("filter") Optional filter, @JsonProperty("constraint") TupleDomain constraint, + @JsonProperty("projectedColumns") Set projectedColumns, @JsonProperty("limit") OptionalInt limit) { this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); this.remoteTableName = requireNonNull(remoteTableName, "remoteTableName is null"); this.filter = requireNonNull(filter, "filter is null"); this.constraint = requireNonNull(constraint, "constraint is null"); + this.projectedColumns = ImmutableSet.copyOf(requireNonNull(projectedColumns, "projectedColumns is null")); this.limit = requireNonNull(limit, "limit is null"); } @@ -80,16 +85,33 @@ public TupleDomain getConstraint() return constraint; } + @JsonProperty + public Set getProjectedColumns() + { + 
return projectedColumns; + } + @JsonProperty public OptionalInt getLimit() { return limit; } + public MongoTableHandle withProjectedColumns(Set projectedColumns) + { + return new MongoTableHandle( + schemaTableName, + remoteTableName, + filter, + constraint, + projectedColumns, + limit); + } + @Override public int hashCode() { - return Objects.hash(schemaTableName, filter, constraint, limit); + return Objects.hash(schemaTableName, filter, constraint, projectedColumns, limit); } @Override @@ -106,6 +128,7 @@ public boolean equals(Object obj) Objects.equals(this.remoteTableName, other.remoteTableName) && Objects.equals(this.filter, other.filter) && Objects.equals(this.constraint, other.constraint) && + Objects.equals(this.projectedColumns, other.projectedColumns) && Objects.equals(this.limit, other.limit); } @@ -117,6 +140,7 @@ public String toString() .add("remoteTableName", remoteTableName) .add("filter", filter) .add("constraint", constraint) + .add("projectedColumns", projectedColumns) .add("limit", limit) .toString(); } diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ptf/Query.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ptf/Query.java index 51227c163229..41d284b8c60e 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ptf/Query.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/ptf/Query.java @@ -138,7 +138,7 @@ public TableFunctionAnalysis analyze( Descriptor returnedType = new Descriptor(columns.stream() .map(MongoColumnHandle.class::cast) - .map(column -> new Descriptor.Field(column.getName(), Optional.of(column.getType()))) + .map(column -> new Descriptor.Field(column.getBaseName(), Optional.of(column.getType()))) .collect(toImmutableList())); QueryFunctionHandle handle = new QueryFunctionHandle(tableHandle); diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java index f084161a8602..5da1e555bc57 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java @@ -13,8 +13,13 @@ */ package io.trino.plugin.mongodb; +import com.google.common.collect.ImmutableList; import io.trino.testing.BaseConnectorSmokeTest; import io.trino.testing.TestingConnectorBehavior; +import io.trino.testing.sql.TestTable; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThat; public abstract class BaseMongoConnectorSmokeTest extends BaseConnectorSmokeTest @@ -43,4 +48,85 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) return super.hasBehavior(connectorBehavior); } } + + @Test + public void testProjectionPushdown() + { + try (TestTable testTable = new TestTable( + getQueryRunner()::execute, + "test_projection_pushdown_multiple_rows_", + "(id INT, nested1 ROW(child1 INT, child2 VARCHAR))", + ImmutableList.of( + "(1, ROW(10, 'a'))", + "(2, ROW(NULL, 'b'))", + "(3, ROW(30, 'c'))", + "(4, NULL)"))) { + assertThat(query("SELECT id, nested1.child1 FROM " + testTable.getName() + " WHERE nested1.child2 = 'c'")) + .matches("VALUES (3, 30)") + .isFullyPushedDown(); + } + } + + @Test + public void testReadDottedField() + { + try (TestTable testTable = new TestTable( + getQueryRunner()::execute, + "test_read_dotted_field_", + "(root ROW(\"dotted.field\" VARCHAR, field VARCHAR))", + 
ImmutableList.of("ROW(ROW('foo', 'bar'))"))) { + assertThat(query("SELECT root.\"dotted.field\" FROM " + testTable.getName())) + .matches("SELECT varchar 'foo'"); + + assertThat(query("SELECT root.\"dotted.field\", root.field FROM " + testTable.getName())) + .matches("SELECT varchar 'foo', varchar 'bar'"); + } + } + + @Test + public void testReadDollarPrefixedField() + { + try (TestTable testTable = new TestTable( + getQueryRunner()::execute, + "test_read_dotted_field_", + "(root ROW(\"$field1\" VARCHAR, field2 VARCHAR))", + ImmutableList.of("ROW(ROW('foo', 'bar'))"))) { + assertThat(query("SELECT root.\"$field1\" FROM " + testTable.getName())) + .matches("SELECT varchar 'foo'"); + + assertThat(query("SELECT root.\"$field1\", root.field2 FROM " + testTable.getName())) + .matches("SELECT varchar 'foo', varchar 'bar'"); + } + } + + @Test + public void testProjectionPushdownWithHighlyNestedData() + { + try (TestTable testTable = new TestTable( + getQueryRunner()::execute, + "test_projection_pushdown_highly_nested_data_", + "(id INT, row1_t ROW(f1 INT, f2 INT, row2_t ROW (f1 INT, f2 INT, row3_t ROW(f1 INT, f2 INT))))", + ImmutableList.of("(1, ROW(2, 3, ROW(4, 5, ROW(6, 7))))", + "(11, ROW(12, 13, ROW(14, 15, ROW(16, 17))))", + "(21, ROW(22, 23, ROW(24, 25, ROW(26, 27))))"))) { + // Test select projected columns, with and without their parent column + assertQuery("SELECT id, row1_t.row2_t.row3_t.f2 FROM " + testTable.getName(), "VALUES (1, 7), (11, 17), (21, 27)"); + assertQuery("SELECT id, row1_t.row2_t.row3_t.f2, CAST(row1_t AS JSON) FROM " + testTable.getName(), + "VALUES (1, 7, '{\"f1\":2,\"f2\":3,\"row2_t\":{\"f1\":4,\"f2\":5,\"row3_t\":{\"f1\":6,\"f2\":7}}}'), " + + "(11, 17, '{\"f1\":12,\"f2\":13,\"row2_t\":{\"f1\":14,\"f2\":15,\"row3_t\":{\"f1\":16,\"f2\":17}}}'), " + + "(21, 27, '{\"f1\":22,\"f2\":23,\"row2_t\":{\"f1\":24,\"f2\":25,\"row3_t\":{\"f1\":26,\"f2\":27}}}')"); + + // Test predicates on immediate child column and deeper nested column + assertQuery("SELECT id, CAST(row1_t.row2_t.row3_t AS JSON) FROM " + testTable.getName() + " WHERE row1_t.row2_t.row3_t.f2 = 27", "VALUES (21, '{\"f1\":26,\"f2\":27}')"); + assertQuery("SELECT id, CAST(row1_t.row2_t.row3_t AS JSON) FROM " + testTable.getName() + " WHERE row1_t.row2_t.row3_t.f2 > 20", "VALUES (21, '{\"f1\":26,\"f2\":27}')"); + assertQuery("SELECT id, CAST(row1_t AS JSON) FROM " + testTable.getName() + " WHERE row1_t.row2_t.row3_t.f2 = 27", + "VALUES (21, '{\"f1\":22,\"f2\":23,\"row2_t\":{\"f1\":24,\"f2\":25,\"row3_t\":{\"f1\":26,\"f2\":27}}}')"); + assertQuery("SELECT id, CAST(row1_t AS JSON) FROM " + testTable.getName() + " WHERE row1_t.row2_t.row3_t.f2 > 20", + "VALUES (21, '{\"f1\":22,\"f2\":23,\"row2_t\":{\"f1\":24,\"f2\":25,\"row3_t\":{\"f1\":26,\"f2\":27}}}')"); + + // Test predicates on parent columns + assertQuery("SELECT id, row1_t.row2_t.row3_t.f1 FROM " + testTable.getName() + " WHERE row1_t.row2_t.row3_t = ROW(16, 17)", "VALUES (11, 16)"); + assertQuery("SELECT id, row1_t.row2_t.row3_t.f1 FROM " + testTable.getName() + " WHERE row1_t = ROW(22, 23, ROW(24, 25, ROW(26, 27)))", "VALUES (21, 26)"); + } + } } diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoClientConfig.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoClientConfig.java index 9b9856078dfc..9576fd534ddb 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoClientConfig.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoClientConfig.java @@ -42,7 
+42,8 @@ public void testDefaults() .setReadPreference(ReadPreferenceType.PRIMARY) .setWriteConcern(WriteConcernType.ACKNOWLEDGED) .setRequiredReplicaSetName(null) - .setImplicitRowFieldPrefix("_pos")); + .setImplicitRowFieldPrefix("_pos") + .setProjectionPushdownEnabled(true)); } @Test @@ -65,6 +66,7 @@ public void testExplicitPropertyMappings() .put("mongodb.write-concern", "UNACKNOWLEDGED") .put("mongodb.required-replica-set", "replica_set") .put("mongodb.implicit-row-field-prefix", "_prefix") + .put("mongodb.projection-pushdown-enabled", "false") .buildOrThrow(); MongoClientConfig expected = new MongoClientConfig() @@ -82,7 +84,8 @@ public void testExplicitPropertyMappings() .setReadPreference(ReadPreferenceType.NEAREST) .setWriteConcern(WriteConcernType.UNACKNOWLEDGED) .setRequiredReplicaSetName("replica_set") - .setImplicitRowFieldPrefix("_prefix"); + .setImplicitRowFieldPrefix("_prefix") + .setProjectionPushdownEnabled(false); assertFullMapping(properties, expected); } diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoComplexTypePredicatePushDown.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoComplexTypePredicatePushDown.java new file mode 100644 index 000000000000..5f3de81b9cf6 --- /dev/null +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoComplexTypePredicatePushDown.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.mongodb; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.testing.BaseComplexTypesPredicatePushDownTest; +import io.trino.testing.QueryRunner; + +import static io.trino.plugin.mongodb.MongoQueryRunner.createMongoQueryRunner; + +public class TestMongoComplexTypePredicatePushDown + extends BaseComplexTypesPredicatePushDownTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + MongoServer server = closeAfterClass(new MongoServer()); + return createMongoQueryRunner(server, ImmutableMap.of(), ImmutableList.of()); + } +} diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java index 60f4d7757108..fc366f53f7a8 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java @@ -15,13 +15,17 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.mongodb.DBRef; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; import com.mongodb.client.MongoDatabase; import com.mongodb.client.model.Collation; import com.mongodb.client.model.CreateCollectionOptions; +import io.trino.Session; +import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.LimitNode; +import io.trino.sql.planner.plan.ProjectNode; import io.trino.testing.BaseConnectorTest; import io.trino.testing.MaterializedResult; import io.trino.testing.MaterializedRow; @@ -43,11 +47,13 @@ import java.util.Date; import java.util.Optional; import java.util.OptionalInt; +import java.util.Set; import static com.mongodb.client.model.CollationCaseFirst.LOWER; import static com.mongodb.client.model.CollationStrength.PRIMARY; import static io.trino.plugin.mongodb.MongoQueryRunner.createMongoClient; import static io.trino.plugin.mongodb.MongoQueryRunner.createMongoQueryRunner; +import static io.trino.plugin.mongodb.TypeUtils.isPushdownSupportedType; import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE; import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.String.format; @@ -100,9 +106,6 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) case SUPPORTS_TRUNCATE: return false; - case SUPPORTS_DEREFERENCE_PUSHDOWN: - return false; - case SUPPORTS_RENAME_SCHEMA: return false; @@ -827,9 +830,10 @@ public void testNativeQueryNestedRow() collection.insertOne(new Document("row_field", new Document("first", new Document("second", 1)))); collection.insertOne(new Document("row_field", new Document("first", new Document("second", 2)))); - assertQuery( - "SELECT row_field.first.second FROM TABLE(mongodb.system.query(database => 'tpch', collection => '" + tableName + "', filter => '{ \"row_field.first.second\": 1 }'))", - "VALUES 1"); + assertThat(query("SELECT row_field.first.second FROM TABLE(mongodb.system.query(database => 'tpch', collection => '" + tableName + "', filter => '{ \"row_field.first.second\": 1 }'))")) + .matches("VALUES BIGINT '1'") + .isFullyPushedDown(); + assertUpdate("DROP TABLE " + tableName); } @@ -882,7 +886,8 @@ public void testNativeQueryLimit() public void testNativeQueryProjection() { assertThat(query("SELECT name FROM 
TABLE(mongodb.system.query(database => 'tpch', collection => 'region', filter => '{}'))")) - .matches("SELECT name FROM region"); + .matches("SELECT name FROM region") + .isFullyPushedDown(); } @Test @@ -973,7 +978,8 @@ public void testReadTopLevelDottedField() assertThat(query("SELECT \"dotted.field\" FROM test." + tableName)) .skippingTypesCheck() - .matches("SELECT NULL"); + .matches("SELECT NULL") + .isFullyPushedDown(); assertUpdate("DROP TABLE test." + tableName); } @@ -988,10 +994,12 @@ public void testReadMiddleLevelDottedField() assertThat(query("SELECT root.\"dotted.field\" FROM test." + tableName)) .skippingTypesCheck() - .matches("SELECT ROW(varchar 'foo')"); + .matches("SELECT ROW(varchar 'foo')") + .isNotFullyPushedDown(ProjectNode.class); assertThat(query("SELECT root.\"dotted.field\".leaf FROM test." + tableName)) - .matches("SELECT varchar 'foo'"); + .matches("SELECT varchar 'foo'") + .isNotFullyPushedDown(ProjectNode.class); assertUpdate("DROP TABLE test." + tableName); } @@ -1005,10 +1013,12 @@ public void testReadLeafLevelDottedField() assertUpdate("INSERT INTO test." + tableName + " SELECT ROW('foo', 'bar')", 1); assertThat(query("SELECT root.\"dotted.field\" FROM test." + tableName)) - .matches("SELECT varchar 'foo'"); + .matches("SELECT varchar 'foo'") + .isNotFullyPushedDown(ProjectNode.class); assertThat(query("SELECT root.\"dotted.field\", root.field FROM test." + tableName)) - .matches("SELECT varchar 'foo', varchar 'bar'"); + .matches("SELECT varchar 'foo', varchar 'bar'") + .isNotFullyPushedDown(ProjectNode.class); assertUpdate("DROP TABLE test." + tableName); } @@ -1022,10 +1032,12 @@ public void testReadWithDollarPrefixedFieldName() assertUpdate("INSERT INTO test." + tableName + " SELECT ROW('foo', 'bar')", 1); assertThat(query("SELECT root.\"$field1\" FROM test." + tableName)) - .matches("SELECT varchar 'foo'"); + .matches("SELECT varchar 'foo'") + .isNotFullyPushedDown(ProjectNode.class); assertThat(query("SELECT root.\"$field1\", root.field2 FROM test." + tableName)) - .matches("SELECT varchar 'foo', varchar 'bar'"); + .matches("SELECT varchar 'foo', varchar 'bar'") + .isNotFullyPushedDown(ProjectNode.class); assertUpdate("DROP TABLE test." + tableName); } @@ -1039,10 +1051,12 @@ public void testReadWithDollarInsideFieldName() assertUpdate("INSERT INTO test." + tableName + " SELECT ROW('foo', 'bar')", 1); assertThat(query("SELECT root.\"fi$ld1\" FROM test." + tableName)) - .matches("SELECT varchar 'foo'"); + .matches("SELECT varchar 'foo'") + .isNotFullyPushedDown(ProjectNode.class); assertThat(query("SELECT root.\"fi$ld1\", root.field2 FROM test." + tableName)) - .matches("SELECT varchar 'foo', varchar 'bar'"); + .matches("SELECT varchar 'foo', varchar 'bar'") + .isNotFullyPushedDown(ProjectNode.class); assertUpdate("DROP TABLE test." + tableName); } @@ -1056,7 +1070,8 @@ public void testReadDottedFieldInsideDollarPrefixedField() assertUpdate("INSERT INTO test." + tableName + " SELECT ROW(ROW('foo'))", 1); assertThat(query("SELECT root.\"$field\".\"dotted.field\" FROM test." + tableName)) - .matches("SELECT varchar 'foo'"); + .matches("SELECT varchar 'foo'") + .isNotFullyPushedDown(ProjectNode.class); assertUpdate("DROP TABLE test." + tableName); } @@ -1070,7 +1085,8 @@ public void testReadDollarPrefixedFieldInsideDottedField() assertUpdate("INSERT INTO test." + tableName + " SELECT ROW(ROW('foo'))", 1); assertThat(query("SELECT root.\"dotted.field\".\"$field\" FROM test." 
+ tableName)) - .matches("SELECT varchar 'foo'"); + .matches("SELECT varchar 'foo'") + .isNotFullyPushedDown(ProjectNode.class); assertUpdate("DROP TABLE test." + tableName); } @@ -1084,7 +1100,8 @@ public void testPredicateOnDottedField() assertUpdate("INSERT INTO test." + tableName + " SELECT ROW('foo')", 1); assertThat(query("SELECT root.\"dotted.field\" FROM test." + tableName + " WHERE root.\"dotted.field\" = 'foo'")) - .matches("SELECT varchar 'foo'"); + .matches("SELECT varchar 'foo'") + .isNotFullyPushedDown(FilterNode.class); assertUpdate("DROP TABLE test." + tableName); } @@ -1098,11 +1115,445 @@ public void testPredicateOnDollarPrefixedField() assertUpdate("INSERT INTO test." + tableName + " SELECT ROW('foo')", 1); assertThat(query("SELECT root.\"$field\" FROM test." + tableName + " WHERE root.\"$field\" = 'foo'")) - .matches("SELECT varchar 'foo'"); + .matches("SELECT varchar 'foo'") + .isNotFullyPushedDown(FilterNode.class); + + assertUpdate("DROP TABLE test." + tableName); + } + + @Test + public void testProjectionPushdownMixedWithUnsupportedFieldName() + { + String tableName = "test_projection_pushdown_mixed_with_unsupported_field_name_" + randomNameSuffix(); + + assertUpdate("CREATE TABLE test." + tableName + " (id INT, root1 ROW(field VARCHAR, \"dotted.field\" VARCHAR), root2 ROW(field VARCHAR, \"$field\" VARCHAR))"); + assertUpdate("INSERT INTO test." + tableName + " SELECT 1, ROW('foo1', 'bar1'), ROW('foo2', 'bar2')", 1); + + assertThat(query("SELECT root1.field, root2.\"$field\" FROM test." + tableName)) + .matches("SELECT varchar 'foo1', varchar 'bar2'") + .isNotFullyPushedDown(ProjectNode.class); + + assertThat(query("SELECT root1.\"dotted.field\", root2.field FROM test." + tableName)) + .matches("SELECT varchar 'bar1', varchar 'foo2'") + .isNotFullyPushedDown(ProjectNode.class); + + assertThat(query("SELECT root1.\"dotted.field\", root2.\"$field\" FROM test." + tableName)) + .matches("SELECT varchar 'bar1', varchar 'bar2'") + .isNotFullyPushedDown(ProjectNode.class); + + assertThat(query("SELECT root1.field, root2.field FROM test." + tableName)) + .matches("SELECT varchar 'foo1', varchar 'foo2'") + .isFullyPushedDown(); + + assertUpdate("DROP TABLE test." 
+ tableName); + } + + @Test(dataProvider = "nestedValuesProvider") + public void testFiltersOnDereferenceColumnReadsLessData(String expectedValue, String expectedType) + { + if (!isPushdownSupportedType(getQueryRunner().getTypeManager().fromSqlType(expectedType))) { + throw new SkipException("Type doesn't support filter pushdown"); + } + + Session sessionWithoutPushdown = Session.builder(getSession()) + .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "projection_pushdown_enabled", "false") + .build(); + + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "filter_on_projection_columns", + format("(col_0 ROW(col_1 %1$s, col_2 ROW(col_3 %1$s, col_4 ROW(col_5 %1$s))))", expectedType))) { + assertUpdate(format("INSERT INTO %s VALUES NULL", table.getName()), 1); + assertUpdate(format("INSERT INTO %1$s SELECT ROW(%2$s, ROW(%2$s, ROW(%2$s)))", table.getName(), expectedValue), 1); + assertUpdate(format("INSERT INTO %1$s SELECT ROW(%2$s, ROW(NULL, ROW(%2$s)))", table.getName(), expectedValue), 1); + + Set expected = ImmutableSet.of(1); + + assertQueryStats( + getSession(), + format("SELECT 1 FROM %s WHERE col_0.col_1 = %s", table.getName(), expectedValue), + statsWithPushdown -> { + long processedInputPositionWithPushdown = statsWithPushdown.getProcessedInputPositions(); + assertQueryStats( + sessionWithoutPushdown, + format("SELECT 1 FROM %s WHERE col_0.col_1 = %s", table.getName(), expectedValue), + statsWithoutPushdown -> { + assertEquals(statsWithoutPushdown.getProcessedInputPositions(), 3); + assertEquals(processedInputPositionWithPushdown, 2); + assertThat(statsWithoutPushdown.getProcessedInputPositions()).isGreaterThan(processedInputPositionWithPushdown); + }, + results -> assertEquals(results.getOnlyColumnAsSet(), expected)); + }, + results -> assertEquals(results.getOnlyColumnAsSet(), expected)); + + assertQueryStats( + getSession(), + format("SELECT 1 FROM %s WHERE col_0.col_2.col_3 = %s", table.getName(), expectedValue), + statsWithPushdown -> { + long processedInputPositionWithPushdown = statsWithPushdown.getProcessedInputPositions(); + assertQueryStats( + sessionWithoutPushdown, + format("SELECT 1 FROM %s WHERE col_0.col_2.col_3 = %s", table.getName(), expectedValue), + statsWithoutPushdown -> { + assertEquals(statsWithoutPushdown.getProcessedInputPositions(), 3); + assertEquals(processedInputPositionWithPushdown, 1); + assertThat(statsWithoutPushdown.getProcessedInputPositions()).isGreaterThan(processedInputPositionWithPushdown); + }, + results -> assertEquals(results.getOnlyColumnAsSet(), expected)); + }, + results -> assertEquals(results.getOnlyColumnAsSet(), expected)); + + assertQueryStats( + getSession(), + format("SELECT 1 FROM %s WHERE col_0.col_2.col_4.col_5 = %s", table.getName(), expectedValue), + statsWithPushdown -> { + long processedInputPositionWithPushdown = statsWithPushdown.getProcessedInputPositions(); + assertQueryStats( + sessionWithoutPushdown, + format("SELECT 1 FROM %s WHERE col_0.col_2.col_4.col_5 = %s", table.getName(), expectedValue), + statsWithoutPushdown -> { + assertEquals(statsWithoutPushdown.getProcessedInputPositions(), 3); + assertEquals(processedInputPositionWithPushdown, 2); + assertThat(statsWithoutPushdown.getProcessedInputPositions()).isGreaterThan(processedInputPositionWithPushdown); + }, + results -> assertEquals(results.getOnlyColumnAsSet(), expected)); + }, + results -> assertEquals(results.getOnlyColumnAsSet(), expected)); + } + } + + @DataProvider + public Object[][] nestedValuesProvider() + { + return 
new Object[][] { + {"varchar 'String type'", "varchar"}, + {"to_utf8('BinData')", "varbinary"}, + {"bigint '1234567890'", "bigint"}, + {"true", "boolean"}, + {"double '12.3'", "double"}, + {"timestamp '1970-01-01 00:00:00.000'", "timestamp(3)"}, + {"array[bigint '1']", "array(bigint)"}, + {"ObjectId('5126bc054aed4daf9e2ab772')", "ObjectId"}, + }; + } + + @Test + public void testFiltersOnDereferenceColumnReadsLessDataNativeQuery() + { + String tableName = "test_filter_on_dereference_column_reads_less_data_native_query_" + randomNameSuffix(); + + MongoCollection collection = client.getDatabase("test").getCollection(tableName); + collection.insertOne(new Document("row_field", new Document("first", new Document("second", 1)))); + collection.insertOne(new Document("row_field", new Document("first", new Document("second", null)))); + collection.insertOne(new Document("row_field", new Document("first", null))); + + assertQueryStats( + getSession(), + "SELECT row_field.first.second FROM TABLE(mongodb.system.query(database => 'test', collection => '" + tableName + "', filter => '{ \"row_field.first.second\": 1 }'))", + stats -> assertEquals(stats.getProcessedInputPositions(), 1L), + results -> assertEquals(results.getOnlyColumnAsSet(), ImmutableSet.of(1L))); + + assertUpdate("DROP TABLE test." + tableName); + } + + @Test + public void testFilterPushdownOnFieldInsideJson() + { + String tableName = "test_filter_pushdown_on_json_" + randomNameSuffix(); + assertUpdate("CREATE TABLE test." + tableName + " (id INT, col JSON)"); + + assertUpdate("INSERT INTO test." + tableName + " VALUES (1, JSON '{\"name\": { \"first\": \"Monika\", \"last\": \"Geller\" }}')", 1); + assertUpdate("INSERT INTO test." + tableName + " VALUES (2, JSON '{\"name\": { \"first\": \"Rachel\", \"last\": \"Green\" }}')", 1); + + assertThat(query("SELECT json_extract_scalar(col, '$.name.first') FROM test." + tableName + " WHERE json_extract_scalar(col, '$.name.last') = 'Geller'")) + .matches("SELECT varchar 'Monika'") + .isNotFullyPushedDown(FilterNode.class); + + assertThat(query("SELECT 1 FROM test." + tableName + " WHERE json_extract_scalar(col, '$.name.last') = 'Geller'")) + .matches("SELECT 1") + .isNotFullyPushedDown(FilterNode.class); + + assertUpdate("DROP TABLE test." + tableName); + } + + @Test + public void testProjectionPushdownWithDifferentTypeInDocuments() + { + String tableName = "test_projection_pushdown_with_different_type_in_document_" + randomNameSuffix(); + + assertUpdate("CREATE TABLE test." + tableName + " (col1 ROW(child VARCHAR))"); + + MongoCollection collection = client.getDatabase("test").getCollection(tableName); + collection.insertOne(new Document("col1", 100)); + collection.insertOne(new Document("col1", new Document("child", "value1"))); + + assertThat(query("SELECT col1.child FROM test." + tableName)) + .skippingTypesCheck() + .matches("VALUES ('value1'), (NULL)") + .isFullyPushedDown(); + + assertUpdate("DROP TABLE test." + tableName); + } + + @Test + public void testProjectionPushdownWithColumnMissingInDocument() + { + String tableName = "test_projection_pushdown_with_column_missing_in_document_" + randomNameSuffix(); + + assertUpdate("CREATE TABLE test." + tableName + " (col1 ROW(child VARCHAR))"); + + MongoCollection collection = client.getDatabase("test").getCollection(tableName); + collection.insertOne(new Document("col1", new Document("child1", "value1"))); + collection.insertOne(new Document("col1", new Document("child", "value2"))); + + assertThat(query("SELECT col1.child FROM test." 
+ tableName)) + .skippingTypesCheck() + .matches("VALUES ('value2'), (NULL)") + .isFullyPushedDown(); + + assertUpdate("DROP TABLE test." + tableName); + } + + @Test(dataProvider = "dbRefProvider") + public void testProjectionPushdownWithDBRef(Object objectId, String expectedValue, String expectedType) + { + String tableName = "test_projection_pushdown_with_dbref_" + randomNameSuffix(); + + DBRef dbRef = new DBRef("test", "creators", objectId); + Document document = new Document() + .append("_id", new ObjectId("5126bbf64aed4daf9e2ab771")) + .append("col1", "foo") + .append("creator", dbRef) + .append("parent", new Document("child", objectId)); + + client.getDatabase("test").getCollection(tableName).insertOne(document); + + assertThat(query("SELECT parent.child, creator.databaseName, creator.collectionName, creator.id FROM test." + tableName)) + .matches("SELECT " + expectedValue + ", varchar 'test', varchar 'creators', " + expectedValue) + .isNotFullyPushedDown(ProjectNode.class); + assertQuery( + "SELECT typeof(creator) FROM test." + tableName, + "SELECT 'row(databaseName varchar, collectionName varchar, id " + expectedType + ")'"); + + assertUpdate("DROP TABLE test." + tableName); + } + + @Test(dataProvider = "dbRefProvider") + public void testProjectionPushdownWithNestedDBRef(Object objectId, String expectedValue, String expectedType) + { + String tableName = "test_projection_pushdown_with_dbref_" + randomNameSuffix(); + + DBRef dbRef = new DBRef("test", "creators", objectId); + Document document = new Document() + .append("_id", new ObjectId("5126bbf64aed4daf9e2ab771")) + .append("col1", "foo") + .append("parent", new Document() + .append("creator", dbRef) + .append("child", objectId)); + + client.getDatabase("test").getCollection(tableName).insertOne(document); + + assertThat(query("SELECT parent.child, parent.creator.databaseName, parent.creator.collectionName, parent.creator.id FROM test." + tableName)) + .matches("SELECT " + expectedValue + ", varchar 'test', varchar 'creators', " + expectedValue) + .isNotFullyPushedDown(ProjectNode.class); + assertQuery( + "SELECT typeof(parent.creator) FROM test." + tableName, + "SELECT 'row(databaseName varchar, collectionName varchar, id " + expectedType + ")'"); + + assertUpdate("DROP TABLE test." + tableName); + } + + @Test(dataProvider = "dbRefProvider") + public void testProjectionPushdownWithPredefinedDBRefKeyword(Object objectId, String expectedValue, String expectedType) + { + String tableName = "test_projection_pushdown_with_predefined_dbref_keyword_" + randomNameSuffix(); + + DBRef dbRef = new DBRef("test", "creators", objectId); + Document document = new Document() + .append("_id", new ObjectId("5126bbf64aed4daf9e2ab771")) + .append("col1", "foo") + .append("parent", new Document("id", dbRef)); + + client.getDatabase("test").getCollection(tableName).insertOne(document); + + assertThat(query("SELECT parent.id, parent.id.id FROM test." + tableName)) + .skippingTypesCheck() + .matches("SELECT row('test', 'creators', %1$s), %1$s".formatted(expectedValue)) + .isNotFullyPushedDown(ProjectNode.class); + assertQuery( + "SELECT typeof(parent.id), typeof(parent.id.id) FROM test." + tableName, + "SELECT 'row(databaseName varchar, collectionName varchar, id %1$s)', '%1$s'".formatted(expectedType)); + + assertUpdate("DROP TABLE test." 
+ tableName); + } + + @Test(dataProvider = "dbRefAndDocumentProvider") + public void testDBRefLikeDocument(Document document1, Document document2, String expectedValue) + { + String tableName = "test_dbref_like_document_" + randomNameSuffix(); + + client.getDatabase("test").getCollection(tableName).insertOne(document1); + client.getDatabase("test").getCollection(tableName).insertOne(document2); + + assertThat(query("SELECT * FROM test." + tableName)) + .skippingTypesCheck() + .matches("VALUES" + + " ROW(ROW(varchar 'dbref_test', varchar 'dbref_creators', " + expectedValue + "))," + + " ROW(ROW(varchar 'doc_test', varchar 'doc_creators', " + expectedValue + "))") + .isFullyPushedDown(); + + assertThat(query("SELECT creator.id FROM test." + tableName)) + .skippingTypesCheck() + .matches("VALUES (%1$s), (%1$s)".formatted(expectedValue)) + .isNotFullyPushedDown(ProjectNode.class); + + assertThat(query("SELECT creator.databasename, creator.collectionname, creator.id FROM test." + tableName)) + .skippingTypesCheck() + .matches("VALUES ('doc_test', 'doc_creators', %1$s), ('dbref_test', 'dbref_creators', %1$s)".formatted(expectedValue)) + .isNotFullyPushedDown(ProjectNode.class); assertUpdate("DROP TABLE test." + tableName); } + @DataProvider + public Object[][] dbRefAndDocumentProvider() + { + Object[][] dbRefObjects = dbRefProvider(); + Object[][] objects = new Object[dbRefObjects.length * 3][]; + int i = 0; + for (Object[] dbRefObject : dbRefObjects) { + Object objectId = dbRefObject[0]; + Object expectedValue = dbRefObject[1]; + Document dbRefDocument = new Document() + .append("_id", new ObjectId("5126bbf64aed4daf9e2ab772")) + .append("creator", new DBRef("dbref_test", "dbref_creators", objectId)); + Document documentWithSameDbRefFieldOrder = new Document() + .append("_id", new ObjectId("5126bbf64aed4daf9e2ab771")) + .append("creator", new Document().append("databaseName", "doc_test").append("collectionName", "doc_creators").append("id", objectId)); + Document documentWithDifferentDbRefFieldOrder = new Document() + .append("_id", new ObjectId("5126bbf64aed4daf9e2ab771")) + .append("creator", new Document().append("collectionName", "doc_creators").append("id", objectId).append("databaseName", "doc_test")); + + objects[i++] = new Object[] {dbRefDocument, documentWithSameDbRefFieldOrder, expectedValue}; + objects[i++] = new Object[] {dbRefDocument, documentWithDifferentDbRefFieldOrder, expectedValue}; + objects[i++] = new Object[] {documentWithSameDbRefFieldOrder, dbRefDocument, expectedValue}; + } + return objects; + } + + @Test(dataProvider = "dbRefProvider") + public void testDBRefLikeDocument(Object objectId, String expectedValue, String expectedType) + { + String tableName = "test_dbref_like_document_fails_" + randomNameSuffix(); + + Document documentWithDifferentDbRefFieldOrder = new Document() + .append("_id", new ObjectId("5126bbf64aed4daf9e2ab771")) + .append("creator", new Document() + .append("databaseName", "doc_test") + .append("collectionName", "doc_creators") + .append("id", objectId)); + Document dbRefDocument = new Document() + .append("_id", new ObjectId("5126bbf64aed4daf9e2ab772")) + .append("creator", new DBRef("dbref_test", "dbref_creators", objectId)); + client.getDatabase("test").getCollection(tableName).insertOne(documentWithDifferentDbRefFieldOrder); + client.getDatabase("test").getCollection(tableName).insertOne(dbRefDocument); + + assertThat(query("SELECT * FROM test." 
+ tableName)) + .skippingTypesCheck() + .matches("VALUES " + + " row(row('doc_test', 'doc_creators', " + expectedValue + "))," + + " row(row('dbref_test', 'dbref_creators', " + expectedValue + "))"); + + assertThat(query("SELECT creator.id FROM test." + tableName)) + .skippingTypesCheck() + .matches("VALUES " + "(%1$s), (%1$s)".formatted(expectedValue)); + + assertThat(query("SELECT creator.databasename, creator.collectionname, creator.id FROM test." + tableName)) + .skippingTypesCheck() + .matches("VALUES " + "('doc_test', 'doc_creators', %1$s), ('dbref_test', 'dbref_creators', %1$s)".formatted(expectedValue)); + + assertUpdate("DROP TABLE test." + tableName); + } + + @Test(dataProvider = "dfRefPredicateProvider") + public void testPredicateOnDBRefField(Object objectId, String expectedValue) + { + String tableName = "test_predicate_on_dbref_field_" + randomNameSuffix(); + + Document document = new Document() + .append("_id", new ObjectId("5126bbf64aed4daf9e2ab771")) + .append("creator", new DBRef("test", "creators", objectId)); + + client.getDatabase("test").getCollection(tableName).insertOne(document); + + assertThat(query("SELECT * FROM test." + tableName + " WHERE creator.id = " + expectedValue)) + .skippingTypesCheck() + .matches("SELECT ROW(varchar 'test', varchar 'creators', " + expectedValue + ")") + .isNotFullyPushedDown(FilterNode.class); + + assertThat(query("SELECT creator.id FROM test." + tableName + " WHERE creator.id = " + expectedValue)) + .skippingTypesCheck() + .matches("SELECT " + expectedValue) + .isNotFullyPushedDown(FilterNode.class); + + assertUpdate("DROP TABLE test." + tableName); + } + + @Test(dataProvider = "dfRefPredicateProvider") + public void testPredicateOnDBRefLikeDocument(Object objectId, String expectedValue) + { + String tableName = "test_predicate_on_dbref_like_document_" + randomNameSuffix(); + + Document document = new Document() + .append("_id", new ObjectId("5126bbf64aed4daf9e2ab771")) + .append("creator", new Document() + .append("databaseName", "test") + .append("collectionName", "creators") + .append("id", objectId)); + + client.getDatabase("test").getCollection(tableName).insertOne(document); + + assertThat(query("SELECT * FROM test." + tableName + " WHERE creator.id = " + expectedValue)) + .skippingTypesCheck() + .matches("SELECT ROW(varchar 'test', varchar 'creators', " + expectedValue + ")") + .isNotFullyPushedDown(FilterNode.class); + + assertThat(query("SELECT creator.id FROM test." + tableName + " WHERE creator.id = " + expectedValue)) + .skippingTypesCheck() + .matches("SELECT " + expectedValue) + .isNotFullyPushedDown(FilterNode.class); + + assertUpdate("DROP TABLE test." 
+ tableName); + } + + @DataProvider + public Object[][] dfRefPredicateProvider() + { + return new Object[][] { + {true, "true"}, + {4, "bigint '4'"}, + {"test", "'test'"}, + {new ObjectId("6216f0c6c432d45190f25e7c"), "ObjectId('6216f0c6c432d45190f25e7c')"}, + {new Date(0), "timestamp '1970-01-01 00:00:00.000'"}, + }; + } + + @Override + @Test + public void testProjectionPushdownReadsLessData() + { + // TODO https://github.com/trinodb/trino/issues/17713 + throw new SkipException("MongoDB connector does not calculate physical data input size"); + } + + @Override + @Test + public void testProjectionPushdownPhysicalInputSize() + { + // TODO https://github.com/trinodb/trino/issues/17713 + throw new SkipException("MongoDB connector does not calculate physical data input size"); + } + @Override protected OptionalInt maxSchemaNameLength() { diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java new file mode 100644 index 000000000000..dc85d8249468 --- /dev/null +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoProjectionPushdownPlans.java @@ -0,0 +1,302 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.mongodb; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.io.Closer; +import com.mongodb.client.MongoClient; +import io.trino.Session; +import io.trino.metadata.QualifiedObjectName; +import io.trino.metadata.TableHandle; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import io.trino.sql.planner.assertions.BasePushdownPlanTest; +import io.trino.sql.planner.assertions.PlanMatchPattern; +import io.trino.testing.LocalQueryRunner; +import org.testng.annotations.AfterClass; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Predicates.equalTo; +import static io.airlift.testing.Closeables.closeAllSuppress; +import static io.trino.plugin.mongodb.MongoQueryRunner.createMongoClient; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.planner.assertions.PlanMatchPattern.any; +import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; +import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; +import static io.trino.sql.planner.assertions.PlanMatchPattern.project; +import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.trino.sql.planner.plan.JoinNode.Type.INNER; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestMongoProjectionPushdownPlans + extends BasePushdownPlanTest +{ + private static final String CATALOG = "mongodb"; + private static final String SCHEMA = "test"; + + private Closer closer; + + @Override + protected LocalQueryRunner createLocalQueryRunner() + { + Session session = testSessionBuilder() + .setCatalog(CATALOG) + .setSchema(SCHEMA) + .build(); + + LocalQueryRunner queryRunner = LocalQueryRunner.create(session); + + closer = Closer.create(); + MongoServer server = closer.register(new MongoServer()); + MongoClient client = closer.register(createMongoClient(server)); + + try { + queryRunner.installPlugin(new MongoPlugin()); + queryRunner.createCatalog( + CATALOG, + "mongodb", + ImmutableMap.of("mongodb.connection-url", server.getConnectionString().toString())); + // Put a dummy collection in the schema because MongoDB doesn't support a database without collections + client.getDatabase(SCHEMA).createCollection("dummy"); + } + catch (Throwable e) { + closeAllSuppress(e, queryRunner); + throw e; + } + return queryRunner; + } + + @AfterClass(alwaysRun = true) + public final void destroy() + throws Exception + { + closer.close(); + closer = null; + } + + @Test + public void testPushdownDisabled() + { + String tableName = "test_pushdown_disabled_" + randomNameSuffix(); + + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setCatalogSessionProperty(CATALOG, "projection_pushdown_enabled", "false") + .build(); + + getQueryRunner().execute("CREATE TABLE " + tableName + " (col0) AS SELECT CAST(row(5, 6) AS row(a bigint, b bigint)) AS col0 WHERE false"); + + assertPlan( + "SELECT col0.a
expr_a, col0.b expr_b FROM " + tableName, + session, + any( + project( + ImmutableMap.of("expr_1", expression("col0[1]"), "expr_2", expression("col0[2]")), + tableScan(tableName, ImmutableMap.of("col0", "col0"))))); + } + + @Test + public void testDereferencePushdown() + { + String tableName = "test_simple_projection_pushdown" + randomNameSuffix(); + QualifiedObjectName completeTableName = new QualifiedObjectName(CATALOG, SCHEMA, tableName); + + getQueryRunner().execute("CREATE TABLE " + tableName + " (col0, col1)" + + " AS SELECT CAST(row(5, 6) AS row(x BIGINT, y BIGINT)) AS col0, BIGINT '5' AS col1"); + + Session session = getQueryRunner().getDefaultSession(); + + Optional tableHandle = getTableHandle(session, completeTableName); + assertThat(tableHandle).as("expected the table handle to be present").isPresent(); + + MongoTableHandle mongoTableHandle = (MongoTableHandle) tableHandle.get().getConnectorHandle(); + Map columns = getColumnHandles(session, completeTableName); + + MongoColumnHandle column0Handle = (MongoColumnHandle) columns.get("col0"); + MongoColumnHandle column1Handle = (MongoColumnHandle) columns.get("col1"); + + MongoColumnHandle columnX = createProjectedColumnHandle(column0Handle, ImmutableList.of("x"), BIGINT); + MongoColumnHandle columnY = createProjectedColumnHandle(column0Handle, ImmutableList.of("y"), BIGINT); + + // Simple Projection pushdown + assertPlan( + "SELECT col0.x expr_x, col0.y expr_y FROM " + tableName, + any( + tableScan( + equalTo(mongoTableHandle.withProjectedColumns(Set.of(columnX, columnY))), + TupleDomain.all(), + ImmutableMap.of("col0.x", equalTo(columnX), "col0.y", equalTo(columnY))))); + + // Projection and predicate pushdown + assertPlan( + "SELECT col0.x FROM " + tableName + " WHERE col0.x = col1 + 3 and col0.y = 2", + anyTree( + filter( + "x = col1 + BIGINT '3'", + tableScan( + table -> { + MongoTableHandle actualTableHandle = (MongoTableHandle) table; + TupleDomain constraint = actualTableHandle.getConstraint(); + return actualTableHandle.getProjectedColumns().equals(ImmutableSet.of(column1Handle, columnX)) + && constraint.equals(TupleDomain.withColumnDomains(ImmutableMap.of(columnY, Domain.singleValue(BIGINT, 2L)))); + }, + TupleDomain.all(), + ImmutableMap.of("col1", equalTo(column1Handle), "x", equalTo(columnX)))))); + + // Projection and predicate pushdown with overlapping columns + assertPlan( + "SELECT col0, col0.y expr_y FROM " + tableName + " WHERE col0.x = 5", + anyTree( + tableScan( + table -> { + MongoTableHandle actualTableHandle = (MongoTableHandle) table; + TupleDomain constraint = actualTableHandle.getConstraint(); + return actualTableHandle.getProjectedColumns().equals(ImmutableSet.of(column0Handle, columnY)) + && constraint.equals(TupleDomain.withColumnDomains(ImmutableMap.of(columnX, Domain.singleValue(BIGINT, 5L)))); + }, + TupleDomain.all(), + ImmutableMap.of("col0", equalTo(column0Handle), "y", equalTo(columnY))))); + + // Projection and predicate pushdown with joins + assertPlan( + "SELECT T.col0.x, T.col0, T.col0.y FROM " + tableName + " T join " + tableName + " S on T.col1 = S.col1 WHERE T.col0.x = 2", + anyTree( + project( + ImmutableMap.of( + "expr_0_x", expression("expr_0[1]"), + "expr_0", expression("expr_0"), + "expr_0_y", expression("expr_0[2]")), + PlanMatchPattern.join(INNER, builder -> builder + .equiCriteria("t_expr_1", "s_expr_1") + .left( + anyTree( + tableScan( + table -> { + MongoTableHandle actualTableHandle = (MongoTableHandle) table; + TupleDomain constraint = actualTableHandle.getConstraint(); + Set 
expectedProjections = ImmutableSet.of(column0Handle, column1Handle); + TupleDomain expectedConstraint = TupleDomain.withColumnDomains( + ImmutableMap.of(columnX, Domain.singleValue(BIGINT, 2L))); + return actualTableHandle.getProjectedColumns().equals(expectedProjections) + && constraint.equals(expectedConstraint); + }, + TupleDomain.all(), + ImmutableMap.of("expr_0", equalTo(column0Handle), "t_expr_1", equalTo(column1Handle))))) + .right( + anyTree( + tableScan( + equalTo(mongoTableHandle.withProjectedColumns(Set.of(column1Handle))), + TupleDomain.all(), + ImmutableMap.of("s_expr_1", equalTo(column1Handle))))))))); + } + + @Test + public void testDereferencePushdownWithDotAndDollarContainingField() + { + String tableName = "test_dereference_pushdown_with_dot_and_dollar_containing_field_" + randomNameSuffix(); + QualifiedObjectName completeTableName = new QualifiedObjectName(CATALOG, SCHEMA, tableName); + + getQueryRunner().execute( + "CREATE TABLE " + tableName + " (id, root1) AS" + + " SELECT BIGINT '1', CAST(ROW(11, ROW(111, ROW(1111, varchar 'foo', varchar 'bar'))) AS" + + " ROW(id BIGINT, root2 ROW(id BIGINT, root3 ROW(id BIGINT, \"dotted.field\" VARCHAR, \"$name\" VARCHAR))))"); + + Session session = getQueryRunner().getDefaultSession(); + + Optional tableHandle = getTableHandle(session, completeTableName); + assertThat(tableHandle).as("expected the table handle to be present").isPresent(); + + MongoTableHandle mongoTableHandle = (MongoTableHandle) tableHandle.get().getConnectorHandle(); + Map columns = getColumnHandles(session, completeTableName); + + RowType rowType = RowType.rowType( + RowType.field("id", BIGINT), + RowType.field("dotted.field", VARCHAR), + RowType.field("$name", VARCHAR)); + + MongoColumnHandle columnRoot1 = (MongoColumnHandle) columns.get("root1"); + MongoColumnHandle columnRoot3 = createProjectedColumnHandle(columnRoot1, ImmutableList.of("root2", "root3"), rowType); + + // The dotted field will not be pushed down, but its parent field 'root1.root2.root3' will be pushed down + assertPlan( + "SELECT root1.root2.root3.\"dotted.field\" FROM " + tableName, + anyTree( + tableScan( + equalTo(mongoTableHandle.withProjectedColumns(Set.of(columnRoot3))), + TupleDomain.all(), + ImmutableMap.of("root1.root2.root3", equalTo(columnRoot3))))); + + // The dollar containing field will not be pushed down, but its parent field 'root1.root2.root3' will be pushed down + assertPlan( + "SELECT root1.root2.root3.\"$name\" FROM " + tableName, + anyTree( + tableScan( + equalTo(mongoTableHandle.withProjectedColumns(Set.of(columnRoot3))), + TupleDomain.all(), + ImmutableMap.of("root1.root2.root3", equalTo(columnRoot3))))); + + assertPlan( + "SELECT 1 FROM " + tableName + " WHERE root1.root2.root3.\"dotted.field\" = 'foo'", + anyTree( + tableScan( + table -> { + MongoTableHandle actualTableHandle = (MongoTableHandle) table; + TupleDomain constraint = actualTableHandle.getConstraint(); + return actualTableHandle.getProjectedColumns().equals(ImmutableSet.of(columnRoot3)) + && constraint.equals(TupleDomain.all()); // Predicate will not be pushed down for the dotted field + }, + TupleDomain.all(), + ImmutableMap.of("root1.root2.root3", equalTo(columnRoot3))))); + + assertPlan( + "SELECT 1 FROM " + tableName + " WHERE root1.root2.root3.\"$name\" = 'bar'", + anyTree( + tableScan( + table -> { + MongoTableHandle actualTableHandle = (MongoTableHandle) table; + TupleDomain constraint = actualTableHandle.getConstraint(); + return actualTableHandle.getProjectedColumns().equals(ImmutableSet.of(columnRoot3)) +
&& constraint.equals(TupleDomain.all()); // Predicate will not be pushed down for the dollar containing field + }, + TupleDomain.all(), + ImmutableMap.of("root1.root2.root3", equalTo(columnRoot3))))); + } + + private MongoColumnHandle createProjectedColumnHandle( + MongoColumnHandle baseColumnHandle, + List dereferenceNames, + Type type) + { + return new MongoColumnHandle( + baseColumnHandle.getBaseName(), + dereferenceNames, + type, + baseColumnHandle.isHidden(), + baseColumnHandle.isDbRefField(), + baseColumnHandle.getComment()); + } +} diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoSession.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoSession.java index 6d86fd9ac773..c0d8c0fac97e 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoSession.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoSession.java @@ -19,12 +19,15 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.predicate.ValueSet; +import io.trino.spi.type.Type; import org.bson.Document; import org.testng.annotations.Test; +import java.util.List; import java.util.Optional; import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.plugin.mongodb.MongoSession.projectSufficientColumns; import static io.trino.spi.predicate.Range.equal; import static io.trino.spi.predicate.Range.greaterThan; import static io.trino.spi.predicate.Range.greaterThanOrEqual; @@ -34,14 +37,17 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; public class TestMongoSession { - private static final MongoColumnHandle COL1 = new MongoColumnHandle("col1", BIGINT, false, Optional.empty()); - private static final MongoColumnHandle COL2 = new MongoColumnHandle("col2", createUnboundedVarcharType(), false, Optional.empty()); - private static final MongoColumnHandle COL3 = new MongoColumnHandle("col3", createUnboundedVarcharType(), false, Optional.empty()); - private static final MongoColumnHandle COL4 = new MongoColumnHandle("col4", BOOLEAN, false, Optional.empty()); + private static final MongoColumnHandle COL1 = createColumnHandle("col1", BIGINT); + private static final MongoColumnHandle COL2 = createColumnHandle("col2", createUnboundedVarcharType()); + private static final MongoColumnHandle COL3 = createColumnHandle("col3", createUnboundedVarcharType()); + private static final MongoColumnHandle COL4 = createColumnHandle("col4", BOOLEAN); + private static final MongoColumnHandle COL5 = createColumnHandle("col5", BIGINT); + private static final MongoColumnHandle COL6 = createColumnHandle("grandparent", createUnboundedVarcharType(), "parent", "col6"); @Test public void testBuildQuery() @@ -52,8 +58,8 @@ public void testBuildQuery() Document query = MongoSession.buildQuery(tupleDomain); Document expected = new Document() - .append(COL1.getName(), new Document().append("$gt", 100L).append("$lte", 200L)) - .append(COL2.getName(), new Document("$eq", "a value")); + .append(COL1.getBaseName(), new Document().append("$gt", 100L).append("$lte", 200L)) + .append(COL2.getBaseName(), new Document("$eq", "a value")); assertEquals(query, expected); } @@ -66,8 +72,8 @@ public void testBuildQueryStringType() Document query = MongoSession.buildQuery(tupleDomain); Document expected = new
Document() - .append(COL3.getName(), new Document().append("$gt", "hello").append("$lte", "world")) - .append(COL2.getName(), new Document("$gte", "a value")); + .append(COL3.getBaseName(), new Document().append("$gt", "hello").append("$lte", "world")) + .append(COL2.getBaseName(), new Document("$gte", "a value")); assertEquals(query, expected); } @@ -78,7 +84,7 @@ public void testBuildQueryIn() COL2, Domain.create(ValueSet.ofRanges(equal(createUnboundedVarcharType(), utf8Slice("hello")), equal(createUnboundedVarcharType(), utf8Slice("world"))), false))); Document query = MongoSession.buildQuery(tupleDomain); - Document expected = new Document(COL2.getName(), new Document("$in", ImmutableList.of("hello", "world"))); + Document expected = new Document(COL2.getBaseName(), new Document("$in", ImmutableList.of("hello", "world"))); assertEquals(query, expected); } @@ -90,8 +96,8 @@ public void testBuildQueryOr() Document query = MongoSession.buildQuery(tupleDomain); Document expected = new Document("$or", asList( - new Document(COL1.getName(), new Document("$lt", 100L)), - new Document(COL1.getName(), new Document("$gt", 200L)))); + new Document(COL1.getBaseName(), new Document("$lt", 100L)), + new Document(COL1.getBaseName(), new Document("$gt", 200L)))); assertEquals(query, expected); } @@ -103,8 +109,8 @@ public void testBuildQueryNull() Document query = MongoSession.buildQuery(tupleDomain); Document expected = new Document("$or", asList( - new Document(COL1.getName(), new Document("$gt", 200L)), - new Document(COL1.getName(), new Document("$eq", null)))); + new Document(COL1.getBaseName(), new Document("$gt", 200L)), + new Document(COL1.getBaseName(), new Document("$eq", null)))); assertEquals(query, expected); } @@ -114,7 +120,94 @@ public void testBooleanPredicatePushdown() TupleDomain tupleDomain = TupleDomain.withColumnDomains(ImmutableMap.of(COL4, Domain.singleValue(BOOLEAN, true))); Document query = MongoSession.buildQuery(tupleDomain); - Document expected = new Document().append(COL4.getName(), new Document("$eq", true)); + Document expected = new Document().append(COL4.getBaseName(), new Document("$eq", true)); assertEquals(query, expected); } + + @Test + public void testBuildQueryNestedField() + { + TupleDomain tupleDomain = TupleDomain.withColumnDomains(ImmutableMap.of( + COL5, Domain.create(ValueSet.ofRanges(greaterThan(BIGINT, 200L)), true), + COL6, Domain.singleValue(createUnboundedVarcharType(), utf8Slice("a value")))); + + Document query = MongoSession.buildQuery(tupleDomain); + Document expected = new Document() + .append("$or", asList( + new Document(COL5.getQualifiedName(), new Document("$gt", 200L)), + new Document(COL5.getQualifiedName(), new Document("$eq", null)))) + .append(COL6.getQualifiedName(), new Document("$eq", "a value")); + assertEquals(query, expected); + } + + @Test + public void testProjectSufficientColumns() + { + MongoColumnHandle col1 = createColumnHandle("x", BIGINT, "a", "b"); + MongoColumnHandle col2 = createColumnHandle("x", BIGINT, "b"); + MongoColumnHandle col3 = createColumnHandle("x", BIGINT, "c"); + MongoColumnHandle col4 = createColumnHandle("x", BIGINT); + + List output = projectSufficientColumns(ImmutableList + .of(col1, col2, col4)); + assertThat(output) + .containsExactly(col4) + .hasSize(1); + + output = projectSufficientColumns(ImmutableList.of(col4, col2, col1)); + assertThat(output) + .containsExactly(col4) + .hasSize(1); + + output = projectSufficientColumns(ImmutableList.of(col2, col1, col4)); + assertThat(output) + 
.containsExactly(col4) + .hasSize(1); + + output = projectSufficientColumns(ImmutableList.of(col2, col3)); + assertThat(output) + .containsExactly(col2, col3) + .hasSize(2); + + MongoColumnHandle col5 = createColumnHandle("x", BIGINT, "a", "b", "c"); + MongoColumnHandle col6 = createColumnHandle("x", BIGINT, "a", "c", "b"); + MongoColumnHandle col7 = createColumnHandle("x", BIGINT, "c", "a", "b"); + MongoColumnHandle col8 = createColumnHandle("x", BIGINT, "b", "a"); + MongoColumnHandle col9 = createColumnHandle("x", BIGINT); + + output = projectSufficientColumns(ImmutableList + .of(col5, col6)); + assertThat(output) + .containsExactly(col5, col6) + .hasSize(2); + + output = projectSufficientColumns(ImmutableList + .of(col6, col7)); + assertThat(output) + .containsExactly(col6, col7) + .hasSize(2); + + output = projectSufficientColumns(ImmutableList + .of(col5, col8)); + assertThat(output) + .containsExactly(col8, col5) + .hasSize(2); + + output = projectSufficientColumns(ImmutableList + .of(col5, col6, col7, col8, col9)); + assertThat(output) + .containsExactly(col9) + .hasSize(1); + } + + private static MongoColumnHandle createColumnHandle(String baseName, Type type, String... dereferenceNames) + { + return new MongoColumnHandle( + baseName, + ImmutableList.copyOf(dereferenceNames), + type, + false, + false, + Optional.empty()); + } } diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTableHandle.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTableHandle.java index 60042f6042c5..0e124e227f23 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTableHandle.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoTableHandle.java @@ -13,17 +13,40 @@ */ package io.trino.plugin.mongodb; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import io.airlift.json.JsonCodec; +import io.airlift.json.JsonCodecFactory; +import io.airlift.json.ObjectMapperProvider; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import io.trino.type.TypeDeserializer; +import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; import static org.testng.Assert.assertEquals; public class TestMongoTableHandle { - private final JsonCodec codec = JsonCodec.jsonCodec(MongoTableHandle.class); + private JsonCodec codec; + + @BeforeClass + public void init() + { + ObjectMapperProvider objectMapperProvider = new ObjectMapperProvider(); + objectMapperProvider.setJsonDeserializers(ImmutableMap.of(Type.class, new TypeDeserializer(TESTING_TYPE_MANAGER))); + codec = new JsonCodecFactory(objectMapperProvider).jsonCodec(MongoTableHandle.class); + } @Test public void testRoundTripWithoutQuery() @@ -76,4 +99,35 @@ public void testRoundTripWithQueryHavingHelperFunction() assertEquals(actual.getSchemaTableName(), expected.getSchemaTableName()); } + + @Test + public void testRoundTripWithProjectedColumns() + { + SchemaTableName schemaTableName = new SchemaTableName("schema", "table"); + RemoteTableName remoteTableName = new RemoteTableName("Schema", "Table"); 
+ Set projectedColumns = ImmutableSet.of( + new MongoColumnHandle("id", ImmutableList.of(), INTEGER, false, false, Optional.empty()), + new MongoColumnHandle("address", ImmutableList.of("street"), VARCHAR, false, false, Optional.empty()), + new MongoColumnHandle( + "user", + ImmutableList.of(), + RowType.from(ImmutableList.of(new RowType.Field(Optional.of("first"), VARCHAR), new RowType.Field(Optional.of("last"), VARCHAR))), + false, + false, + Optional.empty()), + new MongoColumnHandle("creator", ImmutableList.of("databasename"), VARCHAR, false, true, Optional.empty())); + + MongoTableHandle expected = new MongoTableHandle( + schemaTableName, + remoteTableName, + Optional.empty(), + TupleDomain.all(), + projectedColumns, + OptionalInt.empty()); + + String json = codec.toJson(expected); + MongoTableHandle actual = codec.fromJson(json); + + assertEquals(actual, expected); + } }