Skip to content

Commit

Permalink
Implement dereference pushdown for MongoDB connector
Browse files Browse the repository at this point in the history
Co-authored-by: praveenkrishna <[email protected]>
Co-authored-by: Mateusz "Serafin" Gajewski <[email protected]>
  • Loading branch information
Praveen2112 and wendigo committed Jul 4, 2023
1 parent fb2b553 commit dd4bcb0
Show file tree
Hide file tree
Showing 20 changed files with 1,597 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Predicate;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
Expand All @@ -33,31 +34,38 @@ public final class ApplyProjectionUtil
private ApplyProjectionUtil() {}

/**
 * Extracts the sub-expressions of {@code expression} that are supported for projection pushdown,
 * without any additional filtering: delegates to the two-argument overload with a predicate
 * that accepts every expression.
 */
public static List<ConnectorExpression> extractSupportedProjectedColumns(ConnectorExpression expression)
{
    Predicate<ConnectorExpression> acceptAll = ignored -> true;
    return extractSupportedProjectedColumns(expression, acceptAll);
}

/**
 * Collects the sub-expressions of {@code expression} that are supported for projection pushdown.
 *
 * @param expression the expression to inspect; must not be null
 * @param expressionPredicate additional check an expression has to pass to be considered supported
 * @return an immutable list of the supported sub-expressions, in the order they were collected
 */
public static List<ConnectorExpression> extractSupportedProjectedColumns(ConnectorExpression expression, Predicate<ConnectorExpression> expressionPredicate)
{
requireNonNull(expression, "expression is null");
ImmutableList.Builder<ConnectorExpression> supportedSubExpressions = ImmutableList.builder();
fillSupportedProjectedColumns(expression, supportedSubExpressions);
fillSupportedProjectedColumns(expression, supportedSubExpressions, expressionPredicate);
return supportedSubExpressions.build();
}

// Adds the expression itself when it is supported for pushdown; otherwise recurses into
// its children to collect any supported sub-expressions they contain.
private static void fillSupportedProjectedColumns(ConnectorExpression expression, ImmutableList.Builder<ConnectorExpression> supportedSubExpressions)
private static void fillSupportedProjectedColumns(ConnectorExpression expression, ImmutableList.Builder<ConnectorExpression> supportedSubExpressions, Predicate<ConnectorExpression> expressionPredicate)
{
if (isPushDownSupported(expression)) {
if (isPushdownSupported(expression, expressionPredicate)) {
supportedSubExpressions.add(expression);
return;
}

// If the whole expression is not supported, look for a partially supported projection
for (ConnectorExpression child : expression.getChildren()) {
fillSupportedProjectedColumns(child, supportedSubExpressions);
fillSupportedProjectedColumns(child, supportedSubExpressions, expressionPredicate);
}
}

/**
 * Returns whether {@code expression} can be pushed down as a whole: it must satisfy
 * {@code expressionPredicate} and be either a {@link Variable} or a {@link FieldDereference}
 * whose target is itself supported (checked recursively with the same predicate).
 */
@VisibleForTesting
static boolean isPushDownSupported(ConnectorExpression expression)
static boolean isPushdownSupported(ConnectorExpression expression, Predicate<ConnectorExpression> expressionPredicate)
{
return expression instanceof Variable ||
(expression instanceof FieldDereference fieldDereference && isPushDownSupported(fieldDereference.getTarget()));
return expressionPredicate.test(expression)
&& (expression instanceof Variable ||
(expression instanceof FieldDereference fieldDereference
&& isPushdownSupported(fieldDereference.getTarget(), expressionPredicate)));
}

public static ProjectedColumnRepresentation createProjectedColumnRepresentation(ConnectorExpression expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
import io.trino.spi.expression.Constant;
import io.trino.spi.expression.FieldDereference;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.RowType;
import org.testng.annotations.Test;

import static io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns;
import static io.trino.plugin.base.projection.ApplyProjectionUtil.isPushDownSupported;
import static io.trino.plugin.base.projection.ApplyProjectionUtil.isPushdownSupported;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RowType.field;
import static io.trino.spi.type.RowType.rowType;
Expand All @@ -32,6 +33,8 @@
public class TestApplyProjectionUtil
{
private static final ConnectorExpression ROW_OF_ROW_VARIABLE = new Variable("a", rowType(field("b", rowType(field("c", INTEGER)))));
private static final ConnectorExpression LEAF_DOTTED_ROW_OF_ROW_VARIABLE = new Variable("a", rowType(field("b", rowType(field("c.x", INTEGER)))));
private static final ConnectorExpression MID_DOTTED_ROW_OF_ROW_VARIABLE = new Variable("a", rowType(field("b.x", rowType(field("c", INTEGER)))));

private static final ConnectorExpression ONE_LEVEL_DEREFERENCE = new FieldDereference(
rowType(field("c", INTEGER)),
Expand All @@ -43,16 +46,46 @@ public class TestApplyProjectionUtil
ONE_LEVEL_DEREFERENCE,
0);

private static final ConnectorExpression LEAF_DOTTED_ONE_LEVEL_DEREFERENCE = new FieldDereference(
rowType(field("c.x", INTEGER)),
LEAF_DOTTED_ROW_OF_ROW_VARIABLE,
0);

private static final ConnectorExpression LEAF_DOTTED_TWO_LEVEL_DEREFERENCE = new FieldDereference(
INTEGER,
LEAF_DOTTED_ONE_LEVEL_DEREFERENCE,
0);

// Dereferences field 0 ("b.x") of MID_DOTTED_ROW_OF_ROW_VARIABLE. The resulting type is the type of
// that field, row(c integer), so the dotted name occurs only at the middle level of the chain and
// the leaf dereference built on top of this expression refers to the plain field "c".
private static final ConnectorExpression MID_DOTTED_ONE_LEVEL_DEREFERENCE = new FieldDereference(
        rowType(field("c", INTEGER)),
        MID_DOTTED_ROW_OF_ROW_VARIABLE,
        0);

private static final ConnectorExpression MID_DOTTED_TWO_LEVEL_DEREFERENCE = new FieldDereference(
INTEGER,
MID_DOTTED_ONE_LEVEL_DEREFERENCE,
0);

private static final ConnectorExpression INT_VARIABLE = new Variable("a", INTEGER);
private static final ConnectorExpression CONSTANT = new Constant(5, INTEGER);

// Verifies isPushdownSupported for plain variables, dereference chains and constants under
// accept-all, reject-all and field-name-based predicates.
@Test
public void testIsProjectionSupported()
{
assertTrue(isPushDownSupported(ONE_LEVEL_DEREFERENCE));
assertTrue(isPushDownSupported(TWO_LEVEL_DEREFERENCE));
assertTrue(isPushDownSupported(INT_VARIABLE));
assertFalse(isPushDownSupported(CONSTANT));
// With a predicate that accepts everything only the expression shape matters:
// variables and dereference chains are supported, constants are not
assertTrue(isPushdownSupported(ONE_LEVEL_DEREFERENCE, connectorExpression -> true));
assertTrue(isPushdownSupported(TWO_LEVEL_DEREFERENCE, connectorExpression -> true));
assertTrue(isPushdownSupported(INT_VARIABLE, connectorExpression -> true));
assertFalse(isPushdownSupported(CONSTANT, connectorExpression -> true));

// A predicate that rejects everything makes every expression unsupported
assertFalse(isPushdownSupported(ONE_LEVEL_DEREFERENCE, connectorExpression -> false));
assertFalse(isPushdownSupported(TWO_LEVEL_DEREFERENCE, connectorExpression -> false));
assertFalse(isPushdownSupported(INT_VARIABLE, connectorExpression -> false));
assertFalse(isPushdownSupported(CONSTANT, connectorExpression -> false));

// isSupportedForPushDown rejects dereferences of fields whose name contains '.' or '$'
assertTrue(isPushdownSupported(LEAF_DOTTED_ONE_LEVEL_DEREFERENCE, this::isSupportedForPushDown));
assertFalse(isPushdownSupported(LEAF_DOTTED_TWO_LEVEL_DEREFERENCE, this::isSupportedForPushDown));
assertFalse(isPushdownSupported(MID_DOTTED_ONE_LEVEL_DEREFERENCE, this::isSupportedForPushDown));
assertFalse(isPushdownSupported(MID_DOTTED_TWO_LEVEL_DEREFERENCE, this::isSupportedForPushDown));
}

@Test
Expand All @@ -62,5 +95,32 @@ public void testExtractSupportedProjectionColumns()
assertEquals(extractSupportedProjectedColumns(TWO_LEVEL_DEREFERENCE), ImmutableList.of(TWO_LEVEL_DEREFERENCE));
assertEquals(extractSupportedProjectedColumns(INT_VARIABLE), ImmutableList.of(INT_VARIABLE));
assertEquals(extractSupportedProjectedColumns(CONSTANT), ImmutableList.of());

assertEquals(extractSupportedProjectedColumns(ONE_LEVEL_DEREFERENCE, connectorExpression -> false), ImmutableList.of());
assertEquals(extractSupportedProjectedColumns(TWO_LEVEL_DEREFERENCE, connectorExpression -> false), ImmutableList.of());
assertEquals(extractSupportedProjectedColumns(INT_VARIABLE, connectorExpression -> false), ImmutableList.of());
assertEquals(extractSupportedProjectedColumns(CONSTANT, connectorExpression -> false), ImmutableList.of());

// Partial supported projection
assertEquals(extractSupportedProjectedColumns(LEAF_DOTTED_ONE_LEVEL_DEREFERENCE, this::isSupportedForPushDown), ImmutableList.of(LEAF_DOTTED_ONE_LEVEL_DEREFERENCE));
assertEquals(extractSupportedProjectedColumns(LEAF_DOTTED_TWO_LEVEL_DEREFERENCE, this::isSupportedForPushDown), ImmutableList.of(LEAF_DOTTED_ONE_LEVEL_DEREFERENCE));
assertEquals(extractSupportedProjectedColumns(MID_DOTTED_ONE_LEVEL_DEREFERENCE, this::isSupportedForPushDown), ImmutableList.of(MID_DOTTED_ROW_OF_ROW_VARIABLE));
assertEquals(extractSupportedProjectedColumns(MID_DOTTED_TWO_LEVEL_DEREFERENCE, this::isSupportedForPushDown), ImmutableList.of(MID_DOTTED_ROW_OF_ROW_VARIABLE));
}

/**
 * Simulates a connector-side check in which some fields might not be supported for pushdown:
 * a dereference is rejected when the name of the referenced field contains {@code '.'} or
 * {@code '$'}; any expression that is not a {@link FieldDereference} is accepted.
 */
private boolean isSupportedForPushDown(ConnectorExpression connectorExpression)
{
    if (!(connectorExpression instanceof FieldDereference dereference)) {
        return true;
    }
    // The dereferenced field is looked up by index in the row type of the dereference target
    RowType targetType = (RowType) dereference.getTarget().getType();
    String referencedName = targetType.getFields().get(dereference.getField()).getName().get();
    return !(referencedName.contains(".") || referencedName.contains("$"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.plugin.mongodb;

import io.airlift.configuration.Config;
import io.airlift.configuration.ConfigDescription;
import io.airlift.configuration.ConfigSecuritySensitive;
import io.airlift.configuration.DefunctConfig;
import io.airlift.configuration.LegacyConfig;
Expand Down Expand Up @@ -44,6 +45,7 @@ public class MongoClientConfig
private WriteConcernType writeConcern = WriteConcernType.ACKNOWLEDGED;
private String requiredReplicaSetName;
private String implicitRowFieldPrefix = "_pos";
private boolean projectionPushDownEnabled = true;

@NotNull
public String getSchemaCollection()
Expand Down Expand Up @@ -237,4 +239,17 @@ public MongoClientConfig setMaxConnectionIdleTime(int maxConnectionIdleTime)
this.maxConnectionIdleTime = maxConnectionIdleTime;
return this;
}

/**
 * Returns whether projection pushdown is enabled; the backing field is initialized to {@code true}.
 */
public boolean isProjectionPushdownEnabled()
{
return projectionPushDownEnabled;
}

/**
 * Sets the value bound to the {@code mongodb.projection-pushdown-enabled} configuration property.
 *
 * @param projectionPushDownEnabled whether projection pushdown should be enabled
 * @return this config instance, so that setter calls can be chained
 */
@Config("mongodb.projection-pushdown-enabled")
@ConfigDescription("Read only required fields from a row type")
public MongoClientConfig setProjectionPushdownEnabled(boolean projectionPushDownEnabled)
{
this.projectionPushDownEnabled = projectionPushDownEnabled;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.trino.plugin.base.session.SessionPropertiesProvider;
import io.trino.plugin.mongodb.ptf.Query;
import io.trino.spi.function.table.ConnectorTableFunction;
import io.trino.spi.type.TypeManager;
Expand All @@ -44,6 +45,7 @@ public void setup(Binder binder)
binder.bind(MongoSplitManager.class).in(Scopes.SINGLETON);
binder.bind(MongoPageSourceProvider.class).in(Scopes.SINGLETON);
binder.bind(MongoPageSinkProvider.class).in(Scopes.SINGLETON);
newSetBinder(binder, SessionPropertiesProvider.class).addBinding().to(MongoSessionProperties.class).in(Scopes.SINGLETON);

configBinder(binder).bindConfig(MongoClientConfig.class);
newSetBinder(binder, MongoClientSettingConfigurator.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@
package io.trino.plugin.mongodb;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.type.Type;
import org.bson.Document;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

Expand All @@ -28,28 +32,41 @@
public class MongoColumnHandle
implements ColumnHandle
{
private final String name;
private final String baseName;
private final List<String> dereferenceNames;
private final Type type;
private final boolean hidden;
// Represent if the field is inside a DBRef type
private final boolean dbRefField;
private final Optional<String> comment;

/**
 * Creates a column handle; also used for JSON deserialization via the annotated properties.
 * All object arguments must be non-null, and {@code dereferenceNames} is copied into an
 * immutable list.
 */
@JsonCreator
public MongoColumnHandle(
@JsonProperty("name") String name,
@JsonProperty("baseName") String baseName,
@JsonProperty("dereferenceNames") List<String> dereferenceNames,
@JsonProperty("columnType") Type type,
@JsonProperty("hidden") boolean hidden,
@JsonProperty("dbRefField") boolean dbRefField,
@JsonProperty("comment") Optional<String> comment)
{
this.name = requireNonNull(name, "name is null");
this.baseName = requireNonNull(baseName, "baseName is null");
this.dereferenceNames = ImmutableList.copyOf(requireNonNull(dereferenceNames, "dereferenceNames is null"));
this.type = requireNonNull(type, "type is null");
this.hidden = hidden;
this.dbRefField = dbRefField;
this.comment = requireNonNull(comment, "comment is null");
}

/**
 * Returns the base name of the column, i.e. the first segment of the qualified name.
 */
@JsonProperty
public String getName()
public String getBaseName()
{
return name;
return baseName;
}

/**
 * Returns the names that follow the base name in the qualified name of this column;
 * the list is empty for a base column.
 */
@JsonProperty
public List<String> getDereferenceNames()
{
return dereferenceNames;
}

@JsonProperty("columnType")
Expand All @@ -64,6 +81,15 @@ public boolean isHidden()
return hidden;
}

/**
 * Returns whether the field is inside a DBRef type.
 * <p>
 * This method may return a wrong value when a row type uses the same field names and types as a DBRef.
 */
@JsonProperty
public boolean isDbRefField()
{
return dbRefField;
}

@JsonProperty
public Optional<String> getComment()
{
Expand All @@ -73,25 +99,42 @@ public Optional<String> getComment()
/**
 * Builds the {@link ColumnMetadata} for this handle from its name, type, hidden flag and comment.
 */
public ColumnMetadata toColumnMetadata()
{
return ColumnMetadata.builder()
.setName(name)
.setName(getQualifiedName())
.setType(type)
.setHidden(hidden)
.setComment(comment)
.build();
}

/**
 * Returns the base name followed by all dereference names, separated by {@code '.'}.
 * For a handle without dereference names this is just the base name.
 */
@JsonIgnore
public String getQualifiedName()
{
    List<String> nameParts = ImmutableList.<String>builder()
            .add(baseName)
            .addAll(dereferenceNames)
            .build();
    return String.join(".", nameParts);
}

/**
 * Returns {@code true} when this handle has no dereference names, i.e. it refers to the
 * base column itself rather than to a field nested inside it.
 */
@JsonIgnore
public boolean isBaseColumn()
{
return dereferenceNames.isEmpty();
}

/**
 * Builds a BSON {@link Document} describing this column: its name, the string form of its
 * type signature, the hidden and dbRefField flags, and the comment ({@code null} when absent).
 */
public Document getDocument()
{
return new Document().append("name", name)
return new Document().append("name", getQualifiedName())
.append("type", type.getTypeSignature().toString())
.append("hidden", hidden)
.append("dbRefField", dbRefField)
.append("comment", comment.orElse(null));
}

// Hashes the same fields that equals compares
@Override
public int hashCode()
{
return Objects.hash(name, type, hidden, comment);
return Objects.hash(baseName, dereferenceNames, type, hidden, dbRefField, comment);
}

@Override
Expand All @@ -104,15 +147,17 @@ public boolean equals(Object obj)
return false;
}
MongoColumnHandle other = (MongoColumnHandle) obj;
return Objects.equals(name, other.name) &&
return Objects.equals(baseName, other.baseName) &&
Objects.equals(dereferenceNames, other.dereferenceNames) &&
Objects.equals(type, other.type) &&
Objects.equals(hidden, other.hidden) &&
Objects.equals(dbRefField, other.dbRefField) &&
Objects.equals(comment, other.comment);
}

// Renders the handle as "<column name>:<type>"
@Override
public String toString()
{
return name + ":" + type;
return getQualifiedName() + ":" + type;
}
}
Loading

0 comments on commit dd4bcb0

Please sign in to comment.