From 7fb894e8e1bc32d4719fc7d707b486b2e22b0120 Mon Sep 17 00:00:00 2001 From: Steve Lord Date: Fri, 12 Apr 2024 16:24:13 -0700 Subject: [PATCH] Added tests for stateless server --- .../DoPutPreparedStatementResultPOJO.java | 38 ++++ .../flight/sql/example/FlightSqlExample.java | 39 +--- .../example/FlightSqlStatelessExample.java | 198 ++++++++++++++++++ .../arrow/flight/sql/test/TestFlightSql.java | 59 +++--- .../sql/test/TestFlightSqlStateless.java | 88 ++++++++ 5 files changed, 363 insertions(+), 59 deletions(-) create mode 100644 java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/DoPutPreparedStatementResultPOJO.java create mode 100644 java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlStatelessExample.java create mode 100644 java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStateless.java diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/DoPutPreparedStatementResultPOJO.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/DoPutPreparedStatementResultPOJO.java new file mode 100644 index 0000000000000..f85b06269443d --- /dev/null +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/DoPutPreparedStatementResultPOJO.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql.example; + +import java.io.Serializable; + +public class DoPutPreparedStatementResultPOJO implements Serializable { + private transient String query; + private transient byte[] parameters; + + public DoPutPreparedStatementResultPOJO(String query, byte[] parameters) { + this.query = query; + this.parameters = parameters.clone(); + } + + public String getQuery() { + return null; + } + + public byte[] getParameters() { + return null; + } +} diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java index 77875d2d58563..61c44cad73261 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java @@ -124,7 +124,6 @@ import org.apache.arrow.vector.VectorUnloader; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.impl.UnionListWriter; -import org.apache.arrow.vector.ipc.ArrowFileWriter; import org.apache.arrow.vector.ipc.WriteChannel; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.Types.MinorType; @@ -159,12 +158,12 @@ public class FlightSqlExample implements FlightSqlProducer, AutoCloseable { private static final String DATABASE_URI = "jdbc:derby:target/derbyDB"; private static final Logger LOGGER = getLogger(FlightSqlExample.class); - private static final Calendar DEFAULT_CALENDAR = JdbcToArrowUtils.getUtcCalendar(); + protected static final Calendar DEFAULT_CALENDAR = JdbcToArrowUtils.getUtcCalendar(); // ARROW-15315: Use ExecutorService to simulate an async scenario private final ExecutorService executorService = Executors.newFixedThreadPool(10); private final Location location; - private final PoolingDataSource dataSource; - private final BufferAllocator rootAllocator = new RootAllocator(); + protected final PoolingDataSource dataSource; + protected final BufferAllocator rootAllocator = new RootAllocator(); private final Cache> preparedStatementLoadingCache; private final Cache> statementLoadingCache; private final SqlInfoBuilder sqlInfoBuilder; @@ -779,7 +778,7 @@ public void createPreparedStatement(final ActionCreatePreparedStatementRequest r // Running on another thread Future unused = executorService.submit(() -> { try { - final ByteString preparedStatementHandle = copyFrom(randomUUID().toString().getBytes(UTF_8)); + final ByteString preparedStatementHandle = copyFrom(request.getQuery().getBytes(UTF_8)); // Ownership of the connection will be passed to the context. Do NOT close! final Connection connection = dataSource.getConnection(); final PreparedStatement preparedStatement = connection.prepareStatement(request.getQuery(), @@ -912,43 +911,17 @@ public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery co return () -> { assert statementContext != null; PreparedStatement preparedStatement = statementContext.getStatement(); - JdbcParameterBinder binder = null; try { while (flightStream.next()) { final VectorSchemaRoot root = flightStream.getRoot(); - binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build(); + final JdbcParameterBinder binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build(); while (binder.next()) { // Do not execute() - will be done in a getStream call } - final ByteArrayOutputStream out = new ByteArrayOutputStream(); - try ( - ArrowFileWriter writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) - ) { - writer.start(); - writer.writeBatch(); - } - if (out.size() > 0) { - final DoPutPreparedStatementResult doPutPreparedStatementResult = - DoPutPreparedStatementResult.newBuilder() - .setPreparedStatementHandle(ByteString.copyFrom(ByteBuffer.wrap(out.toByteArray()))) - .build(); - - // Update prepared statement cache by storing with new handle and remove old entry. - preparedStatementLoadingCache.put(doPutPreparedStatementResult.getPreparedStatementHandle(), - statementContext); - // TODO: If we invalidate old cached entry here this invalidates the statement, which is not what is needed. - // We need to re-cache the statementContext with a new key. - // preparedStatementLoadingCache.invalidate(command.getPreparedStatementHandle()); - - try (final ArrowBuf buffer = rootAllocator.buffer(doPutPreparedStatementResult.getSerializedSize())) { - buffer.writeBytes(doPutPreparedStatementResult.toByteArray()); - ackStream.onNext(PutResult.metadata(buffer)); - } - } } - } catch (SQLException | IOException e) { + } catch (SQLException e) { ackStream.onError(CallStatus.INTERNAL .withDescription("Failed to bind parameters: " + e.getMessage()) .withCause(e) diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlStatelessExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlStatelessExample.java new file mode 100644 index 0000000000000..65db3a90e2769 --- /dev/null +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlStatelessExample.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql.example; + +import static java.lang.String.format; +import static org.apache.arrow.adapter.jdbc.JdbcToArrow.sqlToArrowVectorIterator; +import static org.apache.arrow.adapter.jdbc.JdbcToArrowUtils.jdbcToArrowSchema; +import static org.apache.arrow.flight.sql.impl.FlightSql.*; +import static org.slf4j.LoggerFactory.getLogger; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +import org.apache.arrow.adapter.jdbc.ArrowVectorIterator; +import org.apache.arrow.adapter.jdbc.JdbcParameterBinder; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.ArrowFileReader; +import org.apache.arrow.vector.ipc.ArrowFileWriter; +import org.apache.arrow.vector.ipc.SeekableReadChannel; +import org.apache.arrow.vector.ipc.message.ArrowBlock; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel; +import org.slf4j.Logger; + +import com.google.protobuf.ByteString; + +/** + * Example {@link FlightSqlProducer} implementation showing an Apache Derby backed Flight SQL server that generally + * supports all current features of Flight SQL. + */ +public class FlightSqlStatelessExample extends FlightSqlExample { + private static final Logger LOGGER = getLogger(FlightSqlStatelessExample.class); + + public static void main(String[] args) throws Exception { + Location location = Location.forGrpcInsecure("localhost", 55555); + final FlightSqlStatelessExample example = new FlightSqlStatelessExample(location); + Location listenLocation = Location.forGrpcInsecure("0.0.0.0", 55555); + try (final BufferAllocator allocator = new RootAllocator(); + final FlightServer server = FlightServer.builder(allocator, listenLocation, example).build()) { + server.start(); + server.awaitTermination(); + } + } + + public FlightSqlStatelessExample(final Location location) { + super(location); + } + + @Override + public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery command, CallContext context, + FlightStream flightStream, StreamListener ackStream) { + + return () -> { + try { + final String query = new String(command.getPreparedStatementHandle().toString("UTF-8")); + final Connection connection = dataSource.getConnection(); + final PreparedStatement preparedStatement = connection.prepareStatement(query, + ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY); + + while (flightStream.next()) { + final VectorSchemaRoot root = flightStream.getRoot(); + final JdbcParameterBinder binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build(); + while (binder.next()) { + // Do not execute() - will be done in a getStream call + } + + final ByteArrayOutputStream parametersStream = new ByteArrayOutputStream(); + try (ArrowFileWriter writer = new ArrowFileWriter(root, null, Channels.newChannel(parametersStream)) + ) { + writer.start(); + writer.writeBatch(); + } + + if (parametersStream.size() > 0) { + final DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO = + new DoPutPreparedStatementResultPOJO(query, parametersStream.toByteArray()); + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(bos)) { + oos.writeObject((Object) doPutPreparedStatementResultPOJO); + final byte[] doPutPreparedStatementResultPOJOArr = bos.toByteArray(); + final DoPutPreparedStatementResult doPutPreparedStatementResult = + DoPutPreparedStatementResult.newBuilder() + .setPreparedStatementHandle( + ByteString.copyFrom(ByteBuffer.wrap(doPutPreparedStatementResultPOJOArr))) + .build(); + + try (final ArrowBuf buffer = rootAllocator.buffer(doPutPreparedStatementResult.getSerializedSize())) { + buffer.writeBytes(doPutPreparedStatementResult.toByteArray()); + ackStream.onNext(PutResult.metadata(buffer)); + } + } + } + } + + } catch (SQLException | IOException e) { + ackStream.onError(CallStatus.INTERNAL + .withDescription("Failed to bind parameters: " + e.getMessage()) + .withCause(e) + .toRuntimeException()); + return; + } + + ackStream.onCompleted(); + }; + } + + @Override + public void getStreamPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context, + final ServerStreamListener listener) { + try { + final byte[] handle = command.getPreparedStatementHandle().asReadOnlyByteBuffer().array(); + try (ByteArrayInputStream bis = new ByteArrayInputStream(handle); + ObjectInputStream ois = new ObjectInputStream(bis)) { + final DoPutPreparedStatementResultPOJO doPutPreparedStatementResultPOJO = + (DoPutPreparedStatementResultPOJO) ois.readObject(); + final String query = doPutPreparedStatementResultPOJO.getQuery(); + final Connection connection = dataSource.getConnection(); + final PreparedStatement statement = connection.prepareStatement(query, + ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY); + + try (ArrowFileReader reader = new ArrowFileReader(new SeekableReadChannel( + new ByteArrayReadableSeekableByteChannel( + doPutPreparedStatementResultPOJO.getParameters())), rootAllocator)) { + + for (ArrowBlock arrowBlock : reader.getRecordBlocks()) { + reader.loadRecordBatch(arrowBlock); + VectorSchemaRoot vectorSchemaRootRecover = reader.getVectorSchemaRoot(); + JdbcParameterBinder binder = JdbcParameterBinder.builder(statement, vectorSchemaRootRecover) + .bindAll().build(); + + while (binder.next()) { + try (final ResultSet resultSet = statement.executeQuery()) { + final Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), DEFAULT_CALENDAR); + try (final VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) { + final VectorLoader loader = new VectorLoader(vectorSchemaRoot); + listener.start(vectorSchemaRoot); + + final ArrowVectorIterator iterator = sqlToArrowVectorIterator(resultSet, rootAllocator); + while (iterator.hasNext()) { + final VectorSchemaRoot batch = iterator.next(); + if (batch.getRowCount() == 0) { + break; + } + final VectorUnloader unloader = new VectorUnloader(batch); + loader.load(unloader.getRecordBatch()); + listener.putNext(); + vectorSchemaRoot.clear(); + } + listener.putNext(); + } + } + } + } + } + } + } catch (final SQLException | IOException | ClassNotFoundException e) { + LOGGER.error(format("Failed to getStreamPreparedStatement: <%s>.", e.getMessage()), e); + listener.error(CallStatus.INTERNAL.withDescription("Failed to prepare statement: " + e).toRuntimeException()); + } finally { + listener.completed(); + } + } +} diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java index a39736e939f0b..fdfcb38116bb0 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java @@ -87,58 +87,65 @@ public class TestFlightSql { Field.nullable("FOREIGNID", MinorType.INT.getType()))); private static final List> EXPECTED_RESULTS_FOR_STAR_SELECT_QUERY = ImmutableList.of( asList("1", "one", "1", "1"), asList("2", "zero", "0", "1"), asList("3", "negative one", "-1", "1")); - private static final List> EXPECTED_RESULTS_FOR_PARAMETER_BINDING = ImmutableList.of( + protected static final List> EXPECTED_RESULTS_FOR_PARAMETER_BINDING = ImmutableList.of( asList("1", "one", "1", "1")); private static final Map GET_SQL_INFO_EXPECTED_RESULTS_MAP = new LinkedHashMap<>(); - private static final String LOCALHOST = "localhost"; - private static BufferAllocator allocator; - private static FlightServer server; - private static FlightSqlClient sqlClient; + protected static final String LOCALHOST = "localhost"; + protected static BufferAllocator allocator; + protected static FlightServer server; + protected static FlightSqlClient sqlClient; @BeforeAll public static void setUp() throws Exception { + setUpClientServer(); + setUpExpectedResultsMap(); + } + + private static void setUpClientServer() throws Exception { allocator = new RootAllocator(Integer.MAX_VALUE); final Location serverLocation = Location.forGrpcInsecure(LOCALHOST, 0); server = FlightServer.builder(allocator, serverLocation, new FlightSqlExample(serverLocation)) - .build() - .start(); + .build() + .start(); final Location clientLocation = Location.forGrpcInsecure(LOCALHOST, server.getPort()); sqlClient = new FlightSqlClient(FlightClient.builder(allocator, clientLocation).build()); + } + protected static void setUpExpectedResultsMap() { GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE), "Apache Derby"); + .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE), "Apache Derby"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION_VALUE), "10.14.2.0 - (1828579)"); + .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION_VALUE), "10.14.2.0 - (1828579)"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION_VALUE), "10.14.2.0 - (1828579)"); + .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION_VALUE), "10.14.2.0 - (1828579)"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE), "false"); + .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE), "false"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.SQL_ALL_TABLES_ARE_SELECTABLE_VALUE), "true"); + .put(Integer.toString(FlightSql.SqlInfo.SQL_ALL_TABLES_ARE_SELECTABLE_VALUE), "true"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put( - Integer.toString(FlightSql.SqlInfo.SQL_NULL_ORDERING_VALUE), - Integer.toString(FlightSql.SqlNullOrdering.SQL_NULLS_SORTED_AT_END_VALUE)); + .put( + Integer.toString(FlightSql.SqlInfo.SQL_NULL_ORDERING_VALUE), + Integer.toString(FlightSql.SqlNullOrdering.SQL_NULLS_SORTED_AT_END_VALUE)); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_CATALOG_VALUE), "false"); + .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_CATALOG_VALUE), "false"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_SCHEMA_VALUE), "true"); + .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_SCHEMA_VALUE), "true"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_TABLE_VALUE), "true"); + .put(Integer.toString(FlightSql.SqlInfo.SQL_DDL_TABLE_VALUE), "true"); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put( - Integer.toString(FlightSql.SqlInfo.SQL_IDENTIFIER_CASE_VALUE), - Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UPPERCASE_VALUE)); + .put( + Integer.toString(FlightSql.SqlInfo.SQL_IDENTIFIER_CASE_VALUE), + Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UPPERCASE_VALUE)); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR_VALUE), "\""); + .put(Integer.toString(FlightSql.SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR_VALUE), "\""); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put( - Integer.toString(FlightSql.SqlInfo.SQL_QUOTED_IDENTIFIER_CASE_VALUE), - Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE_VALUE)); + .put( + Integer.toString(FlightSql.SqlInfo.SQL_QUOTED_IDENTIFIER_CASE_VALUE), + Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE_VALUE)); GET_SQL_INFO_EXPECTED_RESULTS_MAP - .put(Integer.toString(FlightSql.SqlInfo.SQL_MAX_COLUMNS_IN_TABLE_VALUE), "42"); + .put(Integer.toString(FlightSql.SqlInfo.SQL_MAX_COLUMNS_IN_TABLE_VALUE), "42"); } @AfterAll diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStateless.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStateless.java new file mode 100644 index 0000000000000..b9ef6f5ee56b4 --- /dev/null +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSqlStateless.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql.test; + +import static org.apache.arrow.flight.sql.util.FlightStreamUtils.getResults; +import static org.hamcrest.CoreMatchers.*; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.FlightSqlClient.PreparedStatement; +import org.apache.arrow.flight.sql.example.FlightSqlStatelessExample; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; +import org.hamcrest.MatcherAssert; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +/** + * Test direct usage of Flight SQL workflows. + */ +public class TestFlightSqlStateless extends TestFlightSql { + + @BeforeAll + public static void setUp() throws Exception { + setUpClientServer(); + setUpExpectedResultsMap(); + } + + private static void setUpClientServer() throws Exception { + allocator = new RootAllocator(Integer.MAX_VALUE); + + final Location serverLocation = Location.forGrpcInsecure(LOCALHOST, 0); + server = FlightServer.builder(allocator, serverLocation, new FlightSqlStatelessExample(serverLocation)) + .build() + .start(); + + final Location clientLocation = Location.forGrpcInsecure(LOCALHOST, server.getPort()); + sqlClient = new FlightSqlClient(FlightClient.builder(allocator, clientLocation).build()); + } + + @Test + public void testSimplePreparedStatementResultsWithParameterBinding() throws Exception { + try (PreparedStatement prepare = sqlClient.prepare("SELECT * FROM intTable WHERE id = ?")) { + final Schema parameterSchema = prepare.getParameterSchema(); + try (final VectorSchemaRoot insertRoot = VectorSchemaRoot.create(parameterSchema, allocator)) { + insertRoot.allocateNew(); + + final IntVector valueVector = (IntVector) insertRoot.getVector(0); + valueVector.setSafe(0, 1); + insertRoot.setRowCount(1); + + prepare.setParameters(insertRoot); + FlightInfo flightInfo = prepare.execute(); + + FlightStream stream = sqlClient.getStream(flightInfo + .getEndpoints() + .get(0).getTicket()); + + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(SCHEMA_INT_TABLE)), + () -> MatcherAssert.assertThat(getResults(stream), is(EXPECTED_RESULTS_FOR_PARAMETER_BINDING)) + ); + } + } + } +}