Skip to content

Commit

Permalink
GH 37720:[Java][FlightSQL] Implement stateless prepared statement
Browse files Browse the repository at this point in the history
Code clean up
  • Loading branch information
stevelorddremio committed Apr 19, 2024
1 parent 453d2ac commit b8a6e07
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,8 @@ public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery co

return () -> {
try {
final String query = new String(command.getPreparedStatementHandle().toString("UTF-8"));
final Connection connection = dataSource.getConnection();
final PreparedStatement preparedStatement = connection.prepareStatement(query,
ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
final String query = new String(command.getPreparedStatementHandle().toStringUtf8());
final PreparedStatement preparedStatement = createPreparedStatement(query);

while (flightStream.next()) {
final VectorSchemaRoot root = flightStream.getRoot();
Expand All @@ -114,20 +112,16 @@ public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery co
final DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO =
new DoPutPreparedStatementResultPOJO(query, parametersStream.toByteArray());

try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(bos)) {
oos.writeObject(doPutPreparedStatementResultPOJO);
final byte[] doPutPreparedStatementResultPOJOArr = bos.toByteArray();
final DoPutPreparedStatementResult doPutPreparedStatementResult =
DoPutPreparedStatementResult.newBuilder()
.setPreparedStatementHandle(
ByteString.copyFrom(ByteBuffer.wrap(doPutPreparedStatementResultPOJOArr)))
.build();

try (final ArrowBuf buffer = rootAllocator.buffer(doPutPreparedStatementResult.getSerializedSize())) {
buffer.writeBytes(doPutPreparedStatementResult.toByteArray());
ackStream.onNext(PutResult.metadata(buffer));
}
final byte[] doPutPreparedStatementResultPOJOArr = serializePOJO(doPutPreparedStatementResultPOJO);
final DoPutPreparedStatementResult doPutPreparedStatementResult =
DoPutPreparedStatementResult.newBuilder()
.setPreparedStatementHandle(
ByteString.copyFrom(ByteBuffer.wrap(doPutPreparedStatementResultPOJOArr)))
.build();

try (final ArrowBuf buffer = rootAllocator.buffer(doPutPreparedStatementResult.getSerializedSize())) {
buffer.writeBytes(doPutPreparedStatementResult.toByteArray());
ackStream.onNext(PutResult.metadata(buffer));
}
}
}
Expand All @@ -150,14 +144,11 @@ public void getStreamPreparedStatement(final CommandPreparedStatementQuery comma
try {
// Case where there are parameters
final byte[] handle = command.getPreparedStatementHandle().toByteArray();
try (ByteArrayInputStream bis = new ByteArrayInputStream(handle);
ObjectInputStream ois = new ObjectInputStream(bis)) {
try {
final DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO =
(DoPutPreparedStatementResultPOJO) ois.readObject();
deserializePOJO(handle);
final String query = doPutPreparedStatementResultPOJO.getQuery();
final Connection connection = dataSource.getConnection();
final PreparedStatement statement = connection.prepareStatement(query,
ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
final PreparedStatement statement = createPreparedStatement(query);

try (ArrowFileReader reader = new ArrowFileReader(new SeekableReadChannel(
new ByteArrayReadableSeekableByteChannel(
Expand All @@ -176,10 +167,8 @@ public void getStreamPreparedStatement(final CommandPreparedStatementQuery comma
}
} catch (StreamCorruptedException e) {
// Case where there are no parameters
final String query = new String(command.getPreparedStatementHandle().toString("UTF-8"));
final Connection connection = dataSource.getConnection();
final PreparedStatement preparedStatement = connection.prepareStatement(query,
ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
final String query = new String(command.getPreparedStatementHandle().toStringUtf8());
final PreparedStatement preparedStatement = createPreparedStatement(query);
executeQuery(preparedStatement, listener);
}
} catch (final SQLException | IOException | ClassNotFoundException e) {
Expand Down Expand Up @@ -218,20 +207,16 @@ private void executeQuery(PreparedStatement statement,
public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQuery command,
final CallContext context,
final FlightDescriptor descriptor) {
String query = null;
String query;
try {
final byte[] handle = command.getPreparedStatementHandle().toByteArray();
try (ByteArrayInputStream bis = new ByteArrayInputStream(handle);
ObjectInputStream ois = new ObjectInputStream(bis)) {
final DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO =
(DoPutPreparedStatementResultPOJO) ois.readObject();
query = doPutPreparedStatementResultPOJO.getQuery();
final DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO = null;
try {
query = deserializePOJO(handle).getQuery();
} catch (StreamCorruptedException e) {
query = new String(command.getPreparedStatementHandle().toString("UTF-8"));
query = new String(command.getPreparedStatementHandle().toStringUtf8());
}
final Connection connection = dataSource.getConnection();
final PreparedStatement statement = connection.prepareStatement(query,
ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
final PreparedStatement statement = createPreparedStatement(query);

ResultSetMetaData metaData = statement.getMetaData();
return getFlightInfoForSchema(command, descriptor,
Expand All @@ -243,4 +228,24 @@ public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQ
throw CallStatus.INTERNAL.withCause(e).toRuntimeException();
}
}

private DoPutPreparedStatementResultPOJO deserializePOJO(byte[] handle) throws IOException, ClassNotFoundException {
try (ByteArrayInputStream bis = new ByteArrayInputStream(handle);
ObjectInputStream ois = new ObjectInputStream(bis)) {
return (DoPutPreparedStatementResultPOJO) ois.readObject();
}
}

private byte[] serializePOJO(DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO) throws IOException {
try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(bos)) {
oos.writeObject(doPutPreparedStatementResultPOJO);
return bos.toByteArray();
}
}

private PreparedStatement createPreparedStatement(String query) throws SQLException {
final Connection connection = dataSource.getConnection();
return connection.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.arrow.flight.sql.test;

import static org.apache.arrow.flight.sql.util.FlightStreamUtils.getResults;
import static org.apache.arrow.util.AutoCloseables.close;
import static org.hamcrest.CoreMatchers.*;

import org.apache.arrow.flight.FlightClient;
Expand All @@ -33,6 +34,7 @@
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.hamcrest.MatcherAssert;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
Expand All @@ -48,6 +50,11 @@ public static void setUp() throws Exception {
setUpExpectedResultsMap();
}

@AfterAll
public static void tearDown() throws Exception {
close(sqlClient, server, allocator);
}

private static void setUpClientServer() throws Exception {
allocator = new RootAllocator(Integer.MAX_VALUE);

Expand All @@ -72,12 +79,14 @@ public void testSimplePreparedStatementResultsWithParameterBinding() throws Exce
insertRoot.setRowCount(1);

prepare.setParameters(insertRoot);
FlightInfo flightInfo = prepare.execute();
final FlightInfo flightInfo = prepare.execute();

FlightStream stream = sqlClient.getStream(flightInfo
final FlightStream stream = sqlClient.getStream(flightInfo
.getEndpoints()
.get(0).getTicket());

// TODO: root is null and getSchema hangs when run as complete suite.
// This works when run as an individual test.
Assertions.assertAll(
() -> MatcherAssert.assertThat(stream.getSchema(), is(SCHEMA_INT_TABLE)),
() -> MatcherAssert.assertThat(getResults(stream), is(EXPECTED_RESULTS_FOR_PARAMETER_BINDING))
Expand Down

0 comments on commit b8a6e07

Please sign in to comment.