diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index ad19c616ff29a..203cc0e4ae05f 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -111,6 +111,7 @@ private static ArrowFlightSqlClientHandler createNewClientHandler( .withCallOptions(config.toCallOption()) .withRetainCookies(config.retainCookies()) .withRetainAuth(config.retainAuth()) + .withCatalog(config.getCatalog()) .build(); } catch (final SQLException e) { try { @@ -169,6 +170,7 @@ public Properties getClientInfo() { @Override public void close() throws SQLException { + clientHandler.close(); if (executorService != null) { executorService.shutdown(); } diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index edd888ef4df81..ea5bbddc44ef9 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -26,10 +26,12 @@ import java.util.Collection; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils; import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.CloseSessionRequest; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightClientMiddleware; import org.apache.arrow.flight.FlightEndpoint; @@ -38,6 +40,10 @@ import org.apache.arrow.flight.FlightStatusCode; import org.apache.arrow.flight.Location; import org.apache.arrow.flight.LocationSchemes; +import org.apache.arrow.flight.SessionOptionValue; +import org.apache.arrow.flight.SessionOptionValueFactory; +import org.apache.arrow.flight.SetSessionOptionsRequest; +import org.apache.arrow.flight.SetSessionOptionsResult; import org.apache.arrow.flight.auth2.BearerCredentialWriter; import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler; import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; @@ -56,23 +62,31 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableMap; + /** * A {@link FlightSqlClient} handler. */ public final class ArrowFlightSqlClientHandler implements AutoCloseable { private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFlightSqlClientHandler.class); + // JDBC connection string query parameter + private static final String CATALOG = "catalog"; private final FlightSqlClient sqlClient; private final Set options = new HashSet<>(); private final Builder builder; + private final String catalog; ArrowFlightSqlClientHandler(final FlightSqlClient sqlClient, final Builder builder, - final Collection credentialOptions) { + final Collection credentialOptions, + final String catalog) { this.options.addAll(builder.options); this.options.addAll(credentialOptions); this.sqlClient = Preconditions.checkNotNull(sqlClient); this.builder = builder; + this.catalog = catalog; } /** @@ -84,8 +98,9 @@ public final class ArrowFlightSqlClientHandler implements AutoCloseable { */ public static ArrowFlightSqlClientHandler createNewHandler(final FlightClient client, final Builder builder, - final Collection options) { - return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options); + final Collection options, + final String catalog) { + return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options, catalog); } /** @@ -194,6 +209,9 @@ public FlightInfo getInfo(final String query) { @Override public void close() throws SQLException { + if (hasCatalog()) { + sqlClient.closeSession(new CloseSessionRequest(), getOptions()); + } try { AutoCloseables.close(sqlClient); } catch (final Exception e) { @@ -201,6 +219,10 @@ public void close() throws SQLException { } } + private boolean hasCatalog() { + return !Strings.isNullOrEmpty(catalog); + } + /** * A prepared statement handler. */ @@ -254,6 +276,21 @@ public interface PreparedStatement extends AutoCloseable { * @return a new prepared statement. */ public PreparedStatement prepare(final String query) { + if (hasCatalog()) { + final SetSessionOptionsRequest setSessionOptionRequest = + new SetSessionOptionsRequest(ImmutableMap.builder() + .put(CATALOG, SessionOptionValueFactory.makeSessionOptionValue(catalog)) + .build()); + final SetSessionOptionsResult result = sqlClient.setSessionOptions(setSessionOptionRequest, getOptions()); + if (result.hasErrors()) { + Map errors = result.getErrors(); + for (Map.Entry error : errors.entrySet()) { + LOGGER.warn(error.toString()); + } + throw new RuntimeException(String.format("Cannot set session option for catalog = %s", catalog)); + } + } + final FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare(query, getOptions()); return new PreparedStatement() { @@ -495,6 +532,9 @@ public static final class Builder { @VisibleForTesting boolean retainAuth = true; + @VisibleForTesting + String catalog; + // These two middleware are for internal use within build() and should not be exposed by builder APIs. // Note that these middleware may not necessarily be registered. @VisibleForTesting @@ -530,6 +570,7 @@ public Builder() { this.clientCertificatePath = original.clientCertificatePath; this.clientKeyPath = original.clientKeyPath; this.allocator = original.allocator; + this.catalog = original.catalog; if (original.retainCookies) { this.cookieFactory = original.cookieFactory; @@ -761,6 +802,16 @@ public Builder withCallOptions(final Collection options) { return this; } + /** + * Sets the catalog for this handler. + * @param catalog the catalog + * @return this instance. + */ + public Builder withCatalog(final String catalog) { + this.catalog = catalog; + return this; + } + /** * Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields. * @@ -833,7 +884,7 @@ public ArrowFlightSqlClientHandler build() throws SQLException { client, new CredentialCallOption(new BearerCredentialWriter(token)), options.toArray( new CallOption[0]))); } - return ArrowFlightSqlClientHandler.createNewHandler(client, this, credentialOptions); + return ArrowFlightSqlClientHandler.createNewHandler(client, this, credentialOptions, catalog); } catch (final IllegalArgumentException | GeneralSecurityException | IOException | FlightRuntimeException e) { final SQLException originalException = new SQLException(e); diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java index e95cf00bc7a21..1cb4bec54f146 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java @@ -160,6 +160,14 @@ public boolean retainAuth() { return ArrowFlightConnectionProperty.RETAIN_AUTH.getBoolean(properties); } + /** + * The catalog to which a connection is made. + * @return the catalog. + */ + public String getCatalog() { + return ArrowFlightConnectionProperty.CATALOG.getString(properties); + } + /** * Gets the {@link CallOption}s from this {@link ConnectionConfig}. * @@ -210,7 +218,8 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty { THREAD_POOL_SIZE("threadPoolSize", 1, Type.NUMBER, false), TOKEN("token", null, Type.STRING, false), RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false), - RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false); + RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false), + CATALOG("catalog", null, Type.STRING, false); private final String camelName; private final Object defaultValue; diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java index 4fb07428af4ef..a7ce5c8378472 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java @@ -19,6 +19,7 @@ import static java.lang.Runtime.getRuntime; import static java.util.Arrays.asList; +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.CATALOG; import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.HOST; import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PASSWORD; import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT; @@ -91,6 +92,8 @@ public static List provideParameters() { {THREAD_POOL_SIZE, RANDOM.nextInt(getRuntime().availableProcessors()), (Function) ArrowFlightConnectionConfigImpl::threadPoolSize}, + {CATALOG, "catalog", + (Function) ArrowFlightConnectionConfigImpl::getCatalog}, }); } }