From b8a6e0795a2fa0b370c52d198d15a20ad73b518b Mon Sep 17 00:00:00 2001 From: Steve Lord Date: Tue, 16 Apr 2024 13:41:29 -0700 Subject: [PATCH] GH 37720:[Java][FlightSQL] Implement stateless prepared statement Code clean up --- .../example/FlightSqlStatelessExample.java | 81 ++++++++++--------- .../sql/test/TestFlightSqlStateless.java | 13 ++- 2 files changed, 54 insertions(+), 40 deletions(-) diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlStatelessExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlStatelessExample.java index 336f44a203a46..3c522528512ac 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlStatelessExample.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlStatelessExample.java @@ -91,10 +91,8 @@ public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery co return () -> { try { - final String query = new String(command.getPreparedStatementHandle().toString("UTF-8")); - final Connection connection = dataSource.getConnection(); - final PreparedStatement preparedStatement = connection.prepareStatement(query, - ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY); + final String query = new String(command.getPreparedStatementHandle().toStringUtf8()); + final PreparedStatement preparedStatement = createPreparedStatement(query); while (flightStream.next()) { final VectorSchemaRoot root = flightStream.getRoot(); @@ -114,20 +112,16 @@ public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery co final DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO = new DoPutPreparedStatementResultPOJO(query, parametersStream.toByteArray()); - try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); - ObjectOutputStream oos = new ObjectOutputStream(bos)) { - oos.writeObject(doPutPreparedStatementResultPOJO); - final byte[] doPutPreparedStatementResultPOJOArr = bos.toByteArray(); - final DoPutPreparedStatementResult doPutPreparedStatementResult = - DoPutPreparedStatementResult.newBuilder() - .setPreparedStatementHandle( - ByteString.copyFrom(ByteBuffer.wrap(doPutPreparedStatementResultPOJOArr))) - .build(); - - try (final ArrowBuf buffer = rootAllocator.buffer(doPutPreparedStatementResult.getSerializedSize())) { - buffer.writeBytes(doPutPreparedStatementResult.toByteArray()); - ackStream.onNext(PutResult.metadata(buffer)); - } + final byte[] doPutPreparedStatementResultPOJOArr = serializePOJO(doPutPreparedStatementResultPOJO); + final DoPutPreparedStatementResult doPutPreparedStatementResult = + DoPutPreparedStatementResult.newBuilder() + .setPreparedStatementHandle( + ByteString.copyFrom(ByteBuffer.wrap(doPutPreparedStatementResultPOJOArr))) + .build(); + + try (final ArrowBuf buffer = rootAllocator.buffer(doPutPreparedStatementResult.getSerializedSize())) { + buffer.writeBytes(doPutPreparedStatementResult.toByteArray()); + ackStream.onNext(PutResult.metadata(buffer)); } } } @@ -150,14 +144,11 @@ public void getStreamPreparedStatement(final CommandPreparedStatementQuery comma try { // Case where there are parameters final byte[] handle = command.getPreparedStatementHandle().toByteArray(); - try (ByteArrayInputStream bis = new ByteArrayInputStream(handle); - ObjectInputStream ois = new ObjectInputStream(bis)) { + try { final DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO = - (DoPutPreparedStatementResultPOJO) ois.readObject(); + deserializePOJO(handle); final String query = doPutPreparedStatementResultPOJO.getQuery(); - final Connection connection = dataSource.getConnection(); - final PreparedStatement statement = connection.prepareStatement(query, - ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY); + final PreparedStatement statement = createPreparedStatement(query); try (ArrowFileReader reader = new ArrowFileReader(new SeekableReadChannel( new ByteArrayReadableSeekableByteChannel( @@ -176,10 +167,8 @@ public void getStreamPreparedStatement(final CommandPreparedStatementQuery comma } } catch (StreamCorruptedException e) { // Case where there are no parameters - final String query = new String(command.getPreparedStatementHandle().toString("UTF-8")); - final Connection connection = dataSource.getConnection(); - final PreparedStatement preparedStatement = connection.prepareStatement(query, - ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY); + final String query = new String(command.getPreparedStatementHandle().toStringUtf8()); + final PreparedStatement preparedStatement = createPreparedStatement(query); executeQuery(preparedStatement, listener); } } catch (final SQLException | IOException | ClassNotFoundException e) { @@ -218,20 +207,16 @@ private void executeQuery(PreparedStatement statement, public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context, final FlightDescriptor descriptor) { - String query = null; + String query; try { final byte[] handle = command.getPreparedStatementHandle().toByteArray(); - try (ByteArrayInputStream bis = new ByteArrayInputStream(handle); - ObjectInputStream ois = new ObjectInputStream(bis)) { - final DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO = - (DoPutPreparedStatementResultPOJO) ois.readObject(); - query = doPutPreparedStatementResultPOJO.getQuery(); + final DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO = null; + try { + query = deserializePOJO(handle).getQuery(); } catch (StreamCorruptedException e) { - query = new String(command.getPreparedStatementHandle().toString("UTF-8")); + query = new String(command.getPreparedStatementHandle().toStringUtf8()); } - final Connection connection = dataSource.getConnection(); - final PreparedStatement statement = connection.prepareStatement(query, - ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY); + final PreparedStatement statement = createPreparedStatement(query); ResultSetMetaData metaData = statement.getMetaData(); return getFlightInfoForSchema(command, descriptor, @@ -243,4 +228,24 @@ public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQ throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); } } + + private DoPutPreparedStatementResultPOJO deserializePOJO(byte[] handle) throws IOException, ClassNotFoundException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(handle); + ObjectInputStream ois = new ObjectInputStream(bis)) { + return (DoPutPreparedStatementResultPOJO) ois.readObject(); + } + } + + private byte[] serializePOJO(DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO) throws IOException { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(bos)) { + oos.writeObject(doPutPreparedStatementResultPOJO); + return bos.toByteArray(); + } + } + + private PreparedStatement createPreparedStatement(String query) throws SQLException { + final Connection connection = dataSource.getConnection(); + return connection.prepareStatement(query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY); + } } diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStateless.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStateless.java index b9ef6f5ee56b4..77f140da0e32f 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStateless.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStateless.java @@ -18,6 +18,7 @@ package org.apache.arrow.flight.sql.test; import static org.apache.arrow.flight.sql.util.FlightStreamUtils.getResults; +import static org.apache.arrow.util.AutoCloseables.close; import static org.hamcrest.CoreMatchers.*; import org.apache.arrow.flight.FlightClient; @@ -33,6 +34,7 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.Schema; import org.hamcrest.MatcherAssert; +import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -48,6 +50,11 @@ public static void setUp() throws Exception { setUpExpectedResultsMap(); } + @AfterAll + public static void tearDown() throws Exception { + close(sqlClient, server, allocator); + } + private static void setUpClientServer() throws Exception { allocator = new RootAllocator(Integer.MAX_VALUE); @@ -72,12 +79,14 @@ public void testSimplePreparedStatementResultsWithParameterBinding() throws Exce insertRoot.setRowCount(1); prepare.setParameters(insertRoot); - FlightInfo flightInfo = prepare.execute(); + final FlightInfo flightInfo = prepare.execute(); - FlightStream stream = sqlClient.getStream(flightInfo + final FlightStream stream = sqlClient.getStream(flightInfo .getEndpoints() .get(0).getTicket()); + // TODO: root is null and getSchema hangs when run as complete suite. + // This works when run as an individual test. Assertions.assertAll( () -> MatcherAssert.assertThat(stream.getSchema(), is(SCHEMA_INT_TABLE)), () -> MatcherAssert.assertThat(getResults(stream), is(EXPECTED_RESULTS_FOR_PARAMETER_BINDING))