Skip to content

Commit

Permalink
Ensure module instances are not shared by multiple JDBC catalogs
Browse files Browse the repository at this point in the history
When a JdbcPlugin is instantiated, creating an instance of a module at
that time causes the instance to be shared by all connectors provided by
the plugin. This is problematic when the module extends
AbstractConfigurationAwareModule, as it holds a reference to the
ConfigurationFactory, which is set dynamically during bootstrap. If
catalogs are loaded concurrently, this can lead to situations where a
connector accesses the configuration of another connector.
  • Loading branch information
piotrrzysko authored and wendigo committed Nov 10, 2024
1 parent 569838d commit 115b8ea
Show file tree
Hide file tree
Showing 25 changed files with 126 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.trino.spi.type.TypeManager;

import java.util.Map;
import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.isNullOrEmpty;
Expand All @@ -36,13 +37,13 @@ public class JdbcConnectorFactory
implements ConnectorFactory
{
private final String name;
private final Module module;
private final Supplier<Module> module;

public JdbcConnectorFactory(String name, Module module)
public JdbcConnectorFactory(String name, Supplier<Module> module)
{
checkArgument(!isNullOrEmpty(name), "name is null or empty");
this.name = name;
this.module = module;
this.module = requireNonNull(module, "module is null");
}

@Override
Expand All @@ -55,7 +56,6 @@ public String getName()
public Connector create(String catalogName, Map<String, String> requiredConfig, ConnectorContext context)
{
requireNonNull(requiredConfig, "requiredConfig is null");
requireNonNull(module, "module is null");
checkStrictSpiVersionMatch(context, this);

Bootstrap app = new Bootstrap(
Expand All @@ -65,7 +65,7 @@ public Connector create(String catalogName, Map<String, String> requiredConfig,
binder -> binder.bind(OpenTelemetry.class).toInstance(context.getOpenTelemetry()),
binder -> binder.bind(CatalogName.class).toInstance(new CatalogName(catalogName)),
new JdbcModule(),
module);
module.get());

Injector injector = app
.doNotInitializeLogging()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import io.trino.spi.Plugin;
import io.trino.spi.connector.ConnectorFactory;

import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.isNullOrEmpty;
import static io.airlift.configuration.ConfigurationAwareModule.combine;
Expand All @@ -28,9 +30,9 @@ public class JdbcPlugin
implements Plugin
{
private final String name;
private final Module module;
private final Supplier<Module> module;

public JdbcPlugin(String name, Module module)
public JdbcPlugin(String name, Supplier<Module> module)
{
checkArgument(!isNullOrEmpty(name), "name is null or empty");
this.name = name;
Expand All @@ -42,9 +44,9 @@ public Iterable<ConnectorFactory> getConnectorFactories()
{
return ImmutableList.of(new JdbcConnectorFactory(
name,
combine(
() -> combine(
new CredentialProviderModule(),
new ExtraCredentialsBasedIdentityCacheMappingModule(),
module)));
module.get())));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public static QueryRunner createH2QueryRunner(

createSchema(properties, "tpch");

queryRunner.installPlugin(new JdbcPlugin("base_jdbc", module));
queryRunner.installPlugin(new JdbcPlugin("base_jdbc", () -> module));
queryRunner.createCatalog("jdbc", "base_jdbc", properties);

copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, tables);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,32 @@
*/
package io.trino.plugin.jdbc;

import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Binder;
import com.google.inject.Inject;
import com.google.inject.Module;
import com.google.inject.Scopes;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.trino.spi.Plugin;
import io.trino.spi.catalog.CatalogName;
import io.trino.spi.connector.Connector;
import io.trino.spi.connector.ConnectorFactory;
import io.trino.testing.TestingConnectorContext;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicReference;

import static com.google.common.collect.Iterables.getOnlyElement;
import static io.airlift.configuration.ConditionalModule.conditionalModule;
import static io.trino.plugin.base.mapping.MappingConfig.CASE_INSENSITIVE_NAME_MATCHING;
import static io.trino.plugin.base.mapping.RuleBasedIdentifierMappingUtils.createRuleBasedIdentifierMappingFile;
import static io.trino.plugin.jdbc.TestingH2JdbcModule.createH2ConnectionUrl;
import static org.assertj.core.api.Assertions.assertThatCode;

public class TestJdbcPlugin
{
Expand All @@ -47,9 +64,84 @@ public void testRuleBasedIdentifierCanBeUsedTogetherWithCacheBased()
.shutdown();
}

@RepeatedTest(100)
void testConfigurationDoesNotLeakBetweenCatalogs()
{
TestingJdbcPlugin plugin = new TestingJdbcPlugin("test_jdbc", TestingJdbcModule::new);
ConnectorFactory connectorFactory = getOnlyElement(plugin.getConnectorFactories());

try (ExecutorService executor = Executors.newFixedThreadPool(2)) {
Future<Connector> pushDownEnabledFuture = executor.submit(() -> connectorFactory.create(
TestingJdbcModule.CATALOG_WITH_PUSH_DOWN_ENABLED,
ImmutableMap.of("connection-url", createH2ConnectionUrl(), "join-pushdown.enabled", "true"),
new TestingConnectorContext()));
Future<Connector> pushDownDisabledFuture = executor.submit(() -> connectorFactory.create(
TestingJdbcModule.CATALOG_WITH_PUSH_DOWN_DISABLED,
ImmutableMap.of("connection-url", createH2ConnectionUrl(), "join-pushdown.enabled", "false"),
new TestingConnectorContext()));

AtomicReference<Connector> catalogWithPushDownEnabled = new AtomicReference<>();
AtomicReference<Connector> catalogWithPushDownDisabled = new AtomicReference<>();
assertThatCode(() -> {
catalogWithPushDownEnabled.set(pushDownEnabledFuture.get());
catalogWithPushDownDisabled.set(pushDownDisabledFuture.get());
}).doesNotThrowAnyException();

catalogWithPushDownEnabled.get().shutdown();
catalogWithPushDownDisabled.get().shutdown();
}
}

private static class TestingJdbcPlugin
extends JdbcPlugin
{
public TestingJdbcPlugin(String name, Supplier<Module> module)
{
super(name, module);
}
}

private static class TestingJdbcModule
extends AbstractConfigurationAwareModule
{
public static final String CATALOG_WITH_PUSH_DOWN_ENABLED = "catalogWithPushDownEnabled";
public static final String CATALOG_WITH_PUSH_DOWN_DISABLED = "catalogWithPushDownDisabled";

@Override
protected void setup(Binder binder)
{
install(conditionalModule(
JdbcMetadataConfig.class,
JdbcMetadataConfig::isJoinPushdownEnabled,
new ModuleCheckingThatPushDownCanBeEnabled()));
install(new TestingH2JdbcModule());
}
}

private static class ModuleCheckingThatPushDownCanBeEnabled
implements Module
{
@Override
public void configure(Binder binder)
{
binder.bind(PushDownCanBeEnabledChecker.class).in(Scopes.SINGLETON);
}
}

private static class PushDownCanBeEnabledChecker
{
@Inject
public PushDownCanBeEnabledChecker(CatalogName catalogName)
{
if (!TestingJdbcModule.CATALOG_WITH_PUSH_DOWN_ENABLED.equals(catalogName.toString())) {
throw new RuntimeException("Catalog '%s' should not have push-down enabled".formatted(catalogName));
}
}
}

private static ConnectorFactory getConnectorFactory()
{
Plugin plugin = new JdbcPlugin("jdbc", new TestingH2JdbcModule());
Plugin plugin = new JdbcPlugin("jdbc", TestingH2JdbcModule::new);
return getOnlyElement(plugin.getConnectorFactories());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class TestJmxStats
public void testJmxStatsExposure()
throws Exception
{
Plugin plugin = new JdbcPlugin("base_jdbc", new TestingH2JdbcModule());
Plugin plugin = new JdbcPlugin("base_jdbc", TestingH2JdbcModule::new);
ConnectorFactory factory = getOnlyElement(plugin.getConnectorFactories());
factory.create(
"test",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class ClickHousePlugin
{
public ClickHousePlugin()
{
super("clickhouse", new ClickHouseClientModule());
super("clickhouse", ClickHouseClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class DruidJdbcPlugin
{
public DruidJdbcPlugin()
{
super("druid", new DruidJdbcClientModule());
super("druid", DruidJdbcClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class ExamplePlugin
{
public ExamplePlugin()
{
super("example_jdbc", new ExampleClientModule());
super("example_jdbc", ExampleClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class ExasolPlugin
{
public ExasolPlugin()
{
super("exasol", new ExasolClientModule());
super("exasol", ExasolClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class IgnitePlugin
{
public IgnitePlugin()
{
super("ignite", new IgniteClientModule());
super("ignite", IgniteClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class MariaDbPlugin
{
public MariaDbPlugin()
{
super("mariadb", new MariaDbClientModule());
super("mariadb", MariaDbClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class MySqlPlugin
{
public MySqlPlugin()
{
super("mysql", new MySqlClientModule());
super("mysql", MySqlClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class OraclePlugin
{
public OraclePlugin()
{
super("oracle", new OracleClientModule());
super("oracle", OracleClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ public class PostgreSqlPlugin
{
public PostgreSqlPlugin()
{
super("postgresql", combine(new PostgreSqlClientModule(), new PostgreSqlConnectionFactoryModule()));
super("postgresql", () -> combine(new PostgreSqlClientModule(), new PostgreSqlConnectionFactoryModule()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ protected QueryRunner createQueryRunner()
.setAdditionalSetup(runner -> {
runner.installPlugin(new JdbcPlugin(
"counting_postgresql",
combine(new PostgreSqlClientModule(), new TestingPostgreSqlModule(connectionFactory))));
() -> combine(new PostgreSqlClientModule(), new TestingPostgreSqlModule(connectionFactory))));
runner.createCatalog("counting_postgresql", "counting_postgresql", ImmutableMap.of(
"connection-url", postgreSqlServer.getJdbcUrl(),
"connection-user", postgreSqlServer.getUser(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ protected QueryRunner createQueryRunner()
.setAdditionalSetup(runner -> {
runner.installPlugin(new JdbcPlugin(
"counting_postgresql",
combine(new PostgreSqlClientModule(), new TestingPostgreSqlModule(connectionFactory))));
() -> combine(new PostgreSqlClientModule(), new TestingPostgreSqlModule(connectionFactory))));
runner.createCatalog("counting_postgresql", "counting_postgresql", ImmutableMap.of(
"connection-url", postgreSqlServer.getJdbcUrl(),
"connection-user", postgreSqlServer.getUser(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class RedshiftPlugin
{
public RedshiftPlugin()
{
super("redshift", new RedshiftClientModule());
super("redshift", RedshiftClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ public class SingleStorePlugin
{
public SingleStorePlugin()
{
super("singlestore", new SingleStoreClientModule());
super("singlestore", SingleStoreClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class SnowflakePlugin
{
public SnowflakePlugin()
{
super("snowflake", new SnowflakeClientModule());
super("snowflake", SnowflakeClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ public class SqlServerPlugin
{
public SqlServerPlugin()
{
super("sqlserver", combine(new SqlServerClientModule(), new SqlServerConnectionFactoryModule()));
super("sqlserver", () -> combine(new SqlServerClientModule(), new SqlServerConnectionFactoryModule()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ protected QueryRunner createQueryRunner()
.setAdditionalSetup(runner -> {
runner.installPlugin(new JdbcPlugin(
"counting_sqlserver",
combine(new SqlServerClientModule(), new TestingSqlServerModule(connectionFactory))));
() -> combine(new SqlServerClientModule(), new TestingSqlServerModule(connectionFactory))));
runner.createCatalog("counting_sqlserver", "counting_sqlserver", ImmutableMap.of(
"connection-url", sqlServer.getJdbcUrl(),
"connection-user", sqlServer.getUsername(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ protected QueryRunner createQueryRunner()
.setAdditionalSetup(runner -> {
runner.installPlugin(new JdbcPlugin(
"counting_sqlserver",
combine(new SqlServerClientModule(), new TestingSqlServerModule(connectionFactory))));
() -> combine(new SqlServerClientModule(), new TestingSqlServerModule(connectionFactory))));
runner.createCatalog("counting_sqlserver", "counting_sqlserver", ImmutableMap.of(
"connection-url", sqlServer.getJdbcUrl(),
"connection-user", sqlServer.getUsername(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class VerticaPlugin
{
public VerticaPlugin()
{
super("vertica", new VerticaClientModule());
super("vertica", VerticaClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(FunctionId f
}))
.build()));
queryRunner.createCatalog("mock", "mock");
queryRunner.installPlugin(new JdbcPlugin("base_jdbc", new TestingH2JdbcModule()));
queryRunner.installPlugin(new JdbcPlugin("base_jdbc", TestingH2JdbcModule::new));
queryRunner.createCatalog("jdbc", "base_jdbc", TestingH2JdbcModule.createProperties());
for (String tableName : ImmutableList.of("orders", "nation", "region", "lineitem")) {
queryRunner.execute(format("CREATE TABLE %1$s AS SELECT * FROM tpch.tiny.%1$s WITH NO DATA", tableName));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ protected void configureCatalog(QueryRunner queryRunner)
queryRunner.installPlugin(new TpchPlugin());
queryRunner.createCatalog("tpch", "tpch", ImmutableMap.of());

queryRunner.installPlugin(new JdbcPlugin("base_jdbc", new TestingH2JdbcModule()));
queryRunner.installPlugin(new JdbcPlugin("base_jdbc", TestingH2JdbcModule::new));
Map<String, String> jdbcConfigurationProperties = TestingH2JdbcModule.createProperties();
queryRunner.createCatalog("jdbc", "base_jdbc", jdbcConfigurationProperties);

Expand Down

0 comments on commit 115b8ea

Please sign in to comment.