Skip to content

Commit

Permalink
apacheGH-41947: [Java] Support catalog in JDBC driver with session op…
Browse files Browse the repository at this point in the history
…tions
  • Loading branch information
stevelorddremio committed Jun 6, 2024
1 parent 7f0c407 commit f59d7c7
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ private static ArrowFlightSqlClientHandler createNewClientHandler(
.withCallOptions(config.toCallOption())
.withRetainCookies(config.retainCookies())
.withRetainAuth(config.retainAuth())
.withCatalog(config.getCatalog())
.build();
} catch (final SQLException e) {
try {
Expand Down Expand Up @@ -169,6 +170,7 @@ public Properties getClientInfo() {

@Override
public void close() throws SQLException {
clientHandler.close();
if (executorService != null) {
executorService.shutdown();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.CloseSessionRequest;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClientMiddleware;
import org.apache.arrow.flight.FlightEndpoint;
Expand All @@ -38,6 +40,10 @@
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.LocationSchemes;
import org.apache.arrow.flight.SessionOptionValue;
import org.apache.arrow.flight.SessionOptionValueFactory;
import org.apache.arrow.flight.SetSessionOptionsRequest;
import org.apache.arrow.flight.SetSessionOptionsResult;
import org.apache.arrow.flight.auth2.BearerCredentialWriter;
import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler;
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
Expand All @@ -56,23 +62,31 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;

/**
* A {@link FlightSqlClient} handler.
*/
public final class ArrowFlightSqlClientHandler implements AutoCloseable {
private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFlightSqlClientHandler.class);
// JDBC connection string query parameter
private static final String CATALOG = "catalog";

private final FlightSqlClient sqlClient;
private final Set<CallOption> options = new HashSet<>();
private final Builder builder;
private final String catalog;

ArrowFlightSqlClientHandler(final FlightSqlClient sqlClient,
final Builder builder,
final Collection<CallOption> credentialOptions) {
final Collection<CallOption> credentialOptions,
final String catalog) {
this.options.addAll(builder.options);
this.options.addAll(credentialOptions);
this.sqlClient = Preconditions.checkNotNull(sqlClient);
this.builder = builder;
this.catalog = catalog;
}

/**
Expand All @@ -84,8 +98,9 @@ public final class ArrowFlightSqlClientHandler implements AutoCloseable {
*/
public static ArrowFlightSqlClientHandler createNewHandler(final FlightClient client,
final Builder builder,
final Collection<CallOption> options) {
return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options);
final Collection<CallOption> options,
final String catalog) {
return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options, catalog);
}

/**
Expand Down Expand Up @@ -194,13 +209,20 @@ public FlightInfo getInfo(final String query) {

@Override
public void close() throws SQLException {
if (hasCatalog()) {
sqlClient.closeSession(new CloseSessionRequest(), getOptions());
}
try {
AutoCloseables.close(sqlClient);
} catch (final Exception e) {
throw new SQLException("Failed to clean up client resources.", e);
}
}

private boolean hasCatalog() {
return !Strings.isNullOrEmpty(catalog);
}

/**
* A prepared statement handler.
*/
Expand Down Expand Up @@ -254,6 +276,21 @@ public interface PreparedStatement extends AutoCloseable {
* @return a new prepared statement.
*/
public PreparedStatement prepare(final String query) {
if (hasCatalog()) {
final SetSessionOptionsRequest setSessionOptionRequest =
new SetSessionOptionsRequest(ImmutableMap.<String, SessionOptionValue>builder()
.put(CATALOG, SessionOptionValueFactory.makeSessionOptionValue(catalog))
.build());
final SetSessionOptionsResult result = sqlClient.setSessionOptions(setSessionOptionRequest, getOptions());
if (result.hasErrors()) {
Map<String, SetSessionOptionsResult.Error> errors = result.getErrors();
for (Map.Entry<String, SetSessionOptionsResult.Error> error : errors.entrySet()) {
LOGGER.warn(error.toString());
}
throw new RuntimeException(String.format("Cannot set session option for catalog = %s", catalog));
}
}

final FlightSqlClient.PreparedStatement preparedStatement =
sqlClient.prepare(query, getOptions());
return new PreparedStatement() {
Expand Down Expand Up @@ -495,6 +532,9 @@ public static final class Builder {
@VisibleForTesting
boolean retainAuth = true;

@VisibleForTesting
String catalog;

// These two middleware are for internal use within build() and should not be exposed by builder APIs.
// Note that these middleware may not necessarily be registered.
@VisibleForTesting
Expand Down Expand Up @@ -530,6 +570,7 @@ public Builder() {
this.clientCertificatePath = original.clientCertificatePath;
this.clientKeyPath = original.clientKeyPath;
this.allocator = original.allocator;
this.catalog = original.catalog;

if (original.retainCookies) {
this.cookieFactory = original.cookieFactory;
Expand Down Expand Up @@ -761,6 +802,16 @@ public Builder withCallOptions(final Collection<CallOption> options) {
return this;
}

/**
* Sets the catalog for this handler.
* @param catalog the catalog
* @return this instance.
*/
public Builder withCatalog(final String catalog) {
this.catalog = catalog;
return this;
}

/**
* Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields.
*
Expand Down Expand Up @@ -833,7 +884,7 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
client, new CredentialCallOption(new BearerCredentialWriter(token)), options.toArray(
new CallOption[0])));
}
return ArrowFlightSqlClientHandler.createNewHandler(client, this, credentialOptions);
return ArrowFlightSqlClientHandler.createNewHandler(client, this, credentialOptions, catalog);

} catch (final IllegalArgumentException | GeneralSecurityException | IOException | FlightRuntimeException e) {
final SQLException originalException = new SQLException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ public boolean retainAuth() {
return ArrowFlightConnectionProperty.RETAIN_AUTH.getBoolean(properties);
}

/**
* The catalog to which a connection is made.
* @return the catalog.
*/
public String getCatalog() {
return ArrowFlightConnectionProperty.CATALOG.getString(properties);
}

/**
* Gets the {@link CallOption}s from this {@link ConnectionConfig}.
*
Expand Down Expand Up @@ -210,7 +218,8 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty {
THREAD_POOL_SIZE("threadPoolSize", 1, Type.NUMBER, false),
TOKEN("token", null, Type.STRING, false),
RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false),
RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false);
RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false),
CATALOG("catalog", null, Type.STRING, false);

private final String camelName;
private final Object defaultValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import static java.lang.Runtime.getRuntime;
import static java.util.Arrays.asList;
import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.CATALOG;
import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.HOST;
import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PASSWORD;
import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT;
Expand Down Expand Up @@ -91,6 +92,8 @@ public static List<Object[]> provideParameters() {
{THREAD_POOL_SIZE,
RANDOM.nextInt(getRuntime().availableProcessors()),
(Function<ArrowFlightConnectionConfigImpl, ?>) ArrowFlightConnectionConfigImpl::threadPoolSize},
{CATALOG, "catalog",
(Function<ArrowFlightConnectionConfigImpl, ?>) ArrowFlightConnectionConfigImpl::getCatalog},
});
}
}

0 comments on commit f59d7c7

Please sign in to comment.